@@ -23,6 +23,7 @@
 from __future__ import division
 from __future__ import print_function

+import random
 import re

 # Dependency imports
@@ -137,12 +138,112 @@ def _rev_block_forward(x1,
         if layer_scopes is not None:
           layer_scopes.append(layer_vs)
         out = _rev_layer_forward(
-            out, f, g, f_side_input, g_side_input, gate_outputs=gate_outputs)
+            out,
+            f[i],
+            g[i],
+            f_side_input,
+            g_side_input,
+            gate_outputs=gate_outputs)

   y1, y2 = out
   return y1, y2


+def _underlying_variable(t):
+  """Find the underlying variable ref, ignoring Identity ops."""
+  while t.op.type == "Identity":
+    t = t.op.inputs[0]
+  if t.dtype == dtypes.float32_ref and "Variable" in t.op.type:
+    return t
+  else:
+    return None
+
+
+def fn_with_custom_grad(grad_fn):
+  """Decorator to create a subgraph with a custom gradient function.
+
+  The subgraph created by the decorated function is NOT put in a Defun and so
+  does not suffer from the limitations of the Defun (all subgraph ops on the
+  same device, no summaries).
+
+  Args:
+    grad_fn: function with signature
+      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
+      all of which are lists of Tensors.
+
+  Returns:
+    Decorator for function such that the gradient is defined by grad_fn.
+  """
+
+  def dec(fn):
+
+    def wrapped(*args):
+      return _fn_with_custom_grad(fn, args, grad_fn)
+
+    return wrapped
+
+  return dec
+
+
+def _fn_with_custom_grad(fn, inputs, grad_fn):
+  """Create a subgraph with a custom gradient.
+
+  Args:
+    fn: function that takes inputs as arguments and produces 1 or more Tensors.
+    inputs: list<Tensor>, will be passed as fn(*inputs).
+    grad_fn: function with signature
+      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
+      all of which are lists of Tensors.
+
+  Returns:
+    fn(*inputs)
+  """
+  with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs:
+    inputs = list(inputs)
+    outputs = fn(*inputs)
+    train_vars = list(vs.trainable_variables())
+
+  if grad_fn is None:
+    return outputs
+  else:
+    if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
+      outputs = [outputs]
+    outputs = list(outputs)
+
+    in_types = [t.dtype for t in inputs]
+    out_types = [t.dtype for t in outputs]
+    var_types = [t.dtype for t in train_vars]
+
+    def custom_grad_fn(op, *dys):
+      """Custom grad fn applying grad_fn for identity Defun."""
+      dys = list(dys)
+      fn_inputs = op.inputs[:len(inputs)]
+      fn_vars = op.inputs[len(inputs):len(inputs) + len(train_vars)]
+      fn_outputs = op.inputs[len(inputs) + len(train_vars):]
+      assert len(fn_outputs) == len(outputs)
+      assert len(fn_outputs) == len(dys)
+
+      grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
+      grad_outputs = [None] * len(fn_outputs)
+      return tuple(grad_inputs + grad_vars + grad_outputs)
+
+    # The Defun takes as input the original inputs, the trainable variables
+    # created in fn, and the outputs. In the forward it passes through the
+    # outputs. In the backwards, it produces gradients for the original inputs
+    # and the trainable variables.
+    @function.Defun(
+        *(in_types + var_types + out_types),
+        func_name="identity_custom_grad%d" % random.randint(1, 10**9),
+        python_grad_func=custom_grad_fn,
+        shape_func=lambda _: [t.get_shape() for t in outputs])
+    def identity(*args):
+      outs = args[len(inputs) + len(train_vars):]
+      return tuple([tf.identity(t) for t in outs])
+
+    id_out = identity(*(inputs + train_vars + outputs))
+    return id_out
+
+
 def rev_block(x1,
               x2,
               f,
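
A minimal usage sketch of the new `fn_with_custom_grad` decorator (illustrative only, assuming TF 1.x and that the decorator is importable from this module): the `grad_fn` below simply recomputes the standard gradients with `tf.gradients`, so it demonstrates nothing beyond the `(inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars)` calling convention; the names `pass_through_grad` and `layer` are hypothetical.

```python
import tensorflow as tf  # TF 1.x, matching the rest of this file


def pass_through_grad(inputs, variables, outputs, output_grads):
  # Recompute the ordinary gradients with the chain rule; a real use case
  # would substitute a cheaper or otherwise customized gradient here.
  grads = tf.gradients(
      ys=list(outputs), xs=list(inputs) + list(variables),
      grad_ys=list(output_grads))
  return grads[:len(inputs)], grads[len(inputs):]


@fn_with_custom_grad(pass_through_grad)
def layer(x):
  return tf.layers.dense(x, 32, use_bias=False)


x = tf.random_normal([8, 32])
y = layer(x)
# The gradient below is routed through pass_through_grad via the identity
# Defun built by _fn_with_custom_grad.
dx = tf.gradients(tf.reduce_sum(y), x)
```
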
@@ -156,19 +257,29 @@ def rev_block(x1,
   A reversible residual layer is defined as:

   ```
-  y1 = x1 + f(x2)
-  y2 = x2 + g(y1)
+  y1 = x1 + f(x2, f_side_input)
+  y2 = x2 + g(y1, g_side_input)
   ```

+  A reversible residual block, defined here, is a series of reversible residual
+  layers.
+
+  Limitations:
+  * f and g must not close over any Tensors; all side inputs to f and g should
+    be passed in with f_side_input and g_side_input which will be forwarded to
+    f and g.
+  * f and g must not change the dimensionality of their inputs in order for the
+    addition in the equations above to work.
+
   Args:
     x1: a float Tensor.
     x2: a float Tensor.
-    f: a function, (Tensor) -> (Tensor). Should not change the shape of the
-      Tensor. Expected to create variables. See f_side_input if there are side
-      inputs.
-    g: a function, (Tensor) -> (Tensor). Should not change the shape of the
-      Tensor. Expected to create variables. See g_side_input if there are side
-      inputs.
+    f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
+      Should not change the shape of the Tensor. Expected to create variables.
+      See f_side_input if there are side inputs.
+    g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
+      Should not change the shape of the Tensor. Expected to create variables.
+      See g_side_input if there are side inputs.
     num_layers: int, number of reversible residual layers. Each layer will
       apply f and g according to the equations above, with new variables in each
       layer.
@@ -185,46 +296,43 @@ def rev_block(x1,
     f_side_input = []
   if g_side_input is None:
     g_side_input = []
+  if isinstance(f, list):
+    assert len(f) == num_layers
+  else:
+    f = [f] * num_layers
+  if isinstance(g, list):
+    assert len(g) == num_layers
+  else:
+    g = [g] * num_layers

+  # Filled by the forward function below
   layer_scopes = []

-  def rev_block_grad(op, *grad_ys):
+  def custom_grad_fn(inputs, variables, ys, grad_ys):
     """Custom gradient fn for a block of reversible residual layers."""
-    ys = (op.outputs[0], op.outputs[1])
-
-    # The Defun will have as inputs the main inputs (x1, x2), the variables
-    # created inside f and g, and the side inputs to f and g. The order of the
-    # grads returned from this function must match the order of the inputs.
-    # The code here partitions the hoisted inputs into f variables, f side
-    # inputs, g variables, and g side inputs and keeps track of their positions
-    # in hoisted_inputs.
-
-    hoisted_inputs = op.inputs[2:]
-    f_vars = [[] for _ in range(num_layers)]
-    g_vars = [[] for _ in range(num_layers)]
-    f_vars_idxs = [[] for _ in range(num_layers)]
-    g_vars_idxs = [[] for _ in range(num_layers)]
+    side_inputs = inputs[2:]
     f_side_idxs = [None] * len(f_side_input)
     g_side_idxs = [None] * len(g_side_input)
+    assert len(side_inputs) == len(f_side_input) + len(g_side_input)

-    for t in f_side_input + g_side_input:
-      assert t in hoisted_inputs
-
-    for i, t in enumerate(hoisted_inputs):
-      # Side inputs
+    for i, t in enumerate(side_inputs):
       if t in f_side_input:
         f_side_idxs[f_side_input.index(t)] = i
-        continue
-      if t in g_side_input:
+      elif t in g_side_input:
         g_side_idxs[g_side_input.index(t)] = i
-        continue
+      else:
+        assert False

-      # Variables
-      ref = t.op.inputs[0]
-      assert ref.dtype == dtypes.float32_ref
+    f_vars = [[] for _ in range(num_layers)]
+    g_vars = [[] for _ in range(num_layers)]
+    f_vars_idxs = [[] for _ in range(num_layers)]
+    g_vars_idxs = [[] for _ in range(num_layers)]
+
+    for i, t in enumerate(variables):
+      ref = _underlying_variable(t)

       # Use the name to identify the layer number and function (f or g)
-      regex = LAYER_RE.match(t.name)
+      regex = LAYER_RE.match(ref.name)
       layer_no = int(regex.group(1))
       fn_name = regex.group(2)
       if fn_name == "f":
@@ -244,12 +352,15 @@ def rev_block_grad(op, *grad_ys):
     layer_scopes.reverse()
     f_vars.reverse()
     g_vars.reverse()
+    f.reverse()
+    g.reverse()

     for i in xrange(num_layers):
       with tf.variable_scope(layer_scopes[i], reuse=True):
-        ys, grad_ys, f_ret, g_ret = (_rev_layer_backward(
-            ys, grad_ys, f, g, f_vars[i], f_side_input, g_vars[i],
-            g_side_input))
+
+        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(ys, grad_ys, f[i], g[i],
+                                                        f_vars[i], f_side_input,
+                                                        g_vars[i], g_side_input)

       grad_f_vars, grad_f_side = f_ret
       grad_g_vars, grad_g_side = g_ret
@@ -262,8 +373,9 @@ def rev_block_grad(op, *grad_ys):
     acc_f_side_grads = _acc_grads(*f_side_grads)
     acc_g_side_grads = _acc_grads(*g_side_grads)

-    # Use the stored idxs to put gradients in the same order as hoisted_inputs.
-    hoisted_inputs_grads = [None] * len(hoisted_inputs)
+    # Use the stored idxs to put gradients in the passed-in order.
+    side_input_grads = [None] * len(side_inputs)
+    variable_grads = [None] * len(variables)

     # Variable gradients were collected in reverse layer order. Reverse to match
     # idxs.
@@ -272,43 +384,30 @@ def rev_block_grad(op, *grad_ys):
     for idxs, grads in zip(f_vars_idxs, f_var_grads) + zip(
         g_vars_idxs, g_var_grads):
       for i, grad in zip(idxs, grads):
-        hoisted_inputs_grads[i] = grad
+        variable_grads[i] = grad

     for i, grad in zip(f_side_idxs, acc_f_side_grads):
-      hoisted_inputs_grads[i] = grad
+      side_input_grads[i] = grad
     for i, grad in zip(g_side_idxs, acc_g_side_grads):
-      hoisted_inputs_grads[i] = grad
+      side_input_grads[i] = grad

     grad_x1, grad_x2 = grad_ys
-    return [grad_x1, grad_x2] + hoisted_inputs_grads
-
-  @function.Defun(
-      tf.float32,
-      tf.float32,
-      python_grad_func=rev_block_grad,
-      shape_func=lambda _: [x1.get_shape(), x2.get_shape()])
-  def rev_block_defun(inp1, inp2):
-    inp1.set_shape(x1.get_shape())
-    inp2.set_shape(x2.get_shape())
-    return _rev_block_forward(
-        inp1,
-        inp2,
-        f,
-        g,
-        num_layers=num_layers,
-        f_side_input=f_side_input,
-        g_side_input=g_side_input,
-        layer_scopes=layer_scopes,
-        gate_outputs=True)
+    return [grad_x1, grad_x2] + side_input_grads, variable_grads

-  if is_training:
-    return rev_block_defun(x1, x2)
-  else:
+  # Need a forward function with positional arguments
+  @fn_with_custom_grad(custom_grad_fn if is_training else None)
+  def forward(x1, x2, *side_inputs):
+    f_side = side_inputs[:len(f_side_input)]
+    g_side = side_inputs[len(f_side_input):]
     return _rev_block_forward(
         x1,
         x2,
         f,
         g,
         num_layers=num_layers,
-        f_side_input=f_side_input,
-        g_side_input=g_side_input)
+        f_side_input=f_side,
+        g_side_input=g_side,
+        layer_scopes=layer_scopes,
+        gate_outputs=is_training)
+
+  return forward(x1, x2, *(f_side_input + g_side_input))
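
A hedged end-to-end sketch of calling rev_block after this change, assuming TF 1.x, that rev_block is importable from this module, and that the is_training keyword matches the full signature not shown in this diff; the dense layer bodies and their names are illustrative.

```python
import tensorflow as tf


def f(x):
  # Must preserve the input shape so that x1 + f(x2) is well formed.
  return tf.layers.dense(
      x, x.get_shape().as_list()[-1], activation=tf.nn.relu, name="f_dense")


def g(x):
  return tf.layers.dense(
      x, x.get_shape().as_list()[-1], activation=tf.nn.relu, name="g_dense")


x = tf.random_normal([16, 64])
x1, x2 = tf.split(x, 2, axis=-1)  # two halves, 32 channels each

# f and g may also be lists of length num_layers to use different functions
# per layer, per the updated docstring.
y1, y2 = rev_block(x1, x2, f, g, num_layers=4, is_training=True)

loss = tf.reduce_mean(y1 + y2)
# Gradients flow through custom_grad_fn, which reconstructs activations
# layer by layer instead of keeping them all in memory.
grads = tf.gradients(loss, tf.trainable_variables())
```
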