
Commit f5d5405
Author: Ryan Sepassi
Update implementation of rev_block to use new fn_with_custom_grad (which limits usage of Defun)
PiperOrigin-RevId: 165525242
1 parent: 3e295e7

3 files changed: +320 -144 lines


tensor2tensor/layers/rev_block.py (168 additions, 69 deletions)
@@ -23,6 +23,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import random
 import re
 
 # Dependency imports
@@ -137,12 +138,112 @@ def _rev_block_forward(x1,
       if layer_scopes is not None:
         layer_scopes.append(layer_vs)
       out = _rev_layer_forward(
-          out, f, g, f_side_input, g_side_input, gate_outputs=gate_outputs)
+          out,
+          f[i],
+          g[i],
+          f_side_input,
+          g_side_input,
+          gate_outputs=gate_outputs)
 
   y1, y2 = out
   return y1, y2
 
 
+def _underlying_variable(t):
+  """Find the underlying variable ref, ignoring Identity ops."""
+  while t.op.type == "Identity":
+    t = t.op.inputs[0]
+  if t.dtype == dtypes.float32_ref and "Variable" in t.op.type:
+    return t
+  else:
+    return None
+
+
+def fn_with_custom_grad(grad_fn):
+  """Decorator to create a subgraph with a custom gradient function.
+
+  The subgraph created by the decorated function is NOT put in a Defun and so
+  does not suffer from the limitations of the Defun (all subgraph ops on the
+  same device, no summaries).
+
+  Args:
+    grad_fn: function with signature
+      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
+      all of which are lists of Tensors.
+
+  Returns:
+    Decorator for function such that the gradient is defined by grad_fn.
+  """
+
+  def dec(fn):
+
+    def wrapped(*args):
+      return _fn_with_custom_grad(fn, args, grad_fn)
+
+    return wrapped
+
+  return dec
+
+
+def _fn_with_custom_grad(fn, inputs, grad_fn):
+  """Create a subgraph with a custom gradient.
+
+  Args:
+    fn: function that takes inputs as arguments and produces 1 or more Tensors.
+    inputs: list<Tensor>, will be passed as fn(*inputs).
+    grad_fn: function with signature
+      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
+      all of which are lists of Tensors.
+
+  Returns:
+    fn(*inputs)
+  """
+  with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs:
+    inputs = list(inputs)
+    outputs = fn(*inputs)
+    train_vars = list(vs.trainable_variables())
+
+  if grad_fn is None:
+    return outputs
+  else:
+    if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
+      outputs = [outputs]
+    outputs = list(outputs)
+
+    in_types = [t.dtype for t in inputs]
+    out_types = [t.dtype for t in outputs]
+    var_types = [t.dtype for t in train_vars]
+
+    def custom_grad_fn(op, *dys):
+      """Custom grad fn applying grad_fn for identity Defun."""
+      dys = list(dys)
+      fn_inputs = op.inputs[:len(inputs)]
+      fn_vars = op.inputs[len(inputs):len(inputs) + len(train_vars)]
+      fn_outputs = op.inputs[len(inputs) + len(train_vars):]
+      assert len(fn_outputs) == len(outputs)
+      assert len(fn_outputs) == len(dys)
+
+      grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
+      grad_outputs = [None] * len(fn_outputs)
+      return tuple(grad_inputs + grad_vars + grad_outputs)
+
+    # The Defun takes as input the original inputs, the trainable variables
+    # created in fn, and the outputs. In the forward it passes through the
+    # outputs. In the backwards, it produces gradients for the original inputs
+    # and the trainable variables.
+    @function.Defun(
+        *(in_types + var_types + out_types),
+        func_name="identity_custom_grad%d" % random.randint(1, 10**9),
+        python_grad_func=custom_grad_fn,
+        shape_func=lambda _: [t.get_shape() for t in outputs])
+    def identity(*args):
+      outs = args[len(inputs) + len(train_vars):]
+      return tuple([tf.identity(t) for t in outs])
+
+    id_out = identity(*(inputs + train_vars + outputs))
+    return id_out
+
+
 def rev_block(x1,
               x2,
               f,
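For orientation, here is a minimal sketch of how the new fn_with_custom_grad decorator could be used. The gradient function, variable, and shapes below are illustrative assumptions for this sketch, not part of the commit:

```python
import tensorflow as tf
from tensor2tensor.layers import rev_block


def passthrough_grad(inputs, variables, outputs, output_grads):
  # Hypothetical grad_fn matching the documented signature:
  # (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars).
  # Passes the output gradient straight through to the input and returns
  # zero gradients for the variables created inside the decorated function.
  del outputs  # unused in this sketch
  grad_inputs = list(output_grads)
  grad_vars = [tf.zeros_like(v) for v in variables]
  return grad_inputs, grad_vars


@rev_block.fn_with_custom_grad(passthrough_grad)
def layer(x):
  # Creates a trainable variable; its gradient is supplied by passthrough_grad.
  w = tf.get_variable("w", shape=[4, 4], dtype=tf.float32)
  return tf.matmul(x, w)


x = tf.random_normal([2, 4])
y = layer(x)
dx, = tf.gradients(y, x)  # routed through passthrough_grad, not the matmul grad
```

Because the decorated subgraph is built with a plain variable scope rather than inside a Defun, its ops can be placed on different devices and summaries inside it still work; only a thin identity Defun carries the custom gradient.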
@@ -156,19 +257,29 @@ def rev_block(x1,
   A reversible residual layer is defined as:
 
   ```
-  y1 = x1 + f(x2)
-  y2 = x2 + g(y1)
+  y1 = x1 + f(x2, f_side_input)
+  y2 = x2 + g(y1, g_side_input)
   ```
 
+  A reversible residual block, defined here, is a series of reversible residual
+  layers.
+
+  Limitations:
+  * f and g must not close over any Tensors; all side inputs to f and g should
+    be passed in with f_side_input and g_side_input which will be forwarded to
+    f and g.
+  * f and g must not change the dimensionality of their inputs in order for the
+    addition in the equations above to work.
+
   Args:
     x1: a float Tensor.
     x2: a float Tensor.
-    f: a function, (Tensor) -> (Tensor). Should not change the shape of the
-      Tensor. Expected to create variables. See f_side_input if there are side
-      inputs.
-    g: a function, (Tensor) -> (Tensor). Should not change the shape of the
-      Tensor. Expected to create variables. See g_side_input if there are side
-      inputs.
+    f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
+      Should not change the shape of the Tensor. Expected to create variables.
+      See f_side_input if there are side inputs.
+    g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
+      Should not change the shape of the Tensor. Expected to create variables.
+      See g_side_input if there are side inputs.
     num_layers: int, number of reversible residual layers. Each layer will
       apply f and g according to the equations above, with new variables in each
       layer.
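The equations above are what make the block reversible: given y1, y2 and the same f and g, the inputs can be recomputed exactly, which is why the backward pass does not need stored activations. A tiny standalone illustration with toy scalar f and g (side inputs omitted; the real f and g are TensorFlow functions that create variables):

```python
# Toy stand-ins for f and g; any deterministic functions work.
f = lambda x: 3.0 * x + 1.0
g = lambda x: 0.5 * x

def forward(x1, x2):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def inverse(y1, y2):
  x2 = y2 - g(y1)  # g(y1) is recomputable from y1 alone
  x1 = y1 - f(x2)  # then f(x2) from the recovered x2
  return x1, x2

assert inverse(*forward(2.0, 5.0)) == (2.0, 5.0)
```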
@@ -185,46 +296,43 @@ def rev_block_grad(op, *grad_ys):
     f_side_input = []
   if g_side_input is None:
     g_side_input = []
+  if isinstance(f, list):
+    assert len(f) == num_layers
+  else:
+    f = [f] * num_layers
+  if isinstance(g, list):
+    assert len(g) == num_layers
+  else:
+    g = [g] * num_layers
 
+  # Filled by the forward function below
   layer_scopes = []
 
-  def rev_block_grad(op, *grad_ys):
+  def custom_grad_fn(inputs, variables, ys, grad_ys):
     """Custom gradient fn for a block of reversible residual layers."""
-    ys = (op.outputs[0], op.outputs[1])
-
-    # The Defun will have as inputs the main inputs (x1, x2), the variables
-    # created inside f and g, and the side inputs to f and g. The order of the
-    # grads returned from this function must match the order of the inputs.
-    # The code here partitions the hoisted inputs into f variables, f side
-    # inputs, g variables, and g side inputs and keeps track of their positions
-    # in hoisted_inputs.
-
-    hoisted_inputs = op.inputs[2:]
-    f_vars = [[] for _ in range(num_layers)]
-    g_vars = [[] for _ in range(num_layers)]
-    f_vars_idxs = [[] for _ in range(num_layers)]
-    g_vars_idxs = [[] for _ in range(num_layers)]
+    side_inputs = inputs[2:]
     f_side_idxs = [None] * len(f_side_input)
     g_side_idxs = [None] * len(g_side_input)
+    assert len(side_inputs) == len(f_side_input) + len(g_side_input)
 
-    for t in f_side_input + g_side_input:
-      assert t in hoisted_inputs
-
-    for i, t in enumerate(hoisted_inputs):
-      # Side inputs
+    for i, t in enumerate(side_inputs):
       if t in f_side_input:
         f_side_idxs[f_side_input.index(t)] = i
-        continue
-      if t in g_side_input:
+      elif t in g_side_input:
         g_side_idxs[g_side_input.index(t)] = i
-        continue
+      else:
+        assert False
 
-      # Variables
-      ref = t.op.inputs[0]
-      assert ref.dtype == dtypes.float32_ref
+    f_vars = [[] for _ in range(num_layers)]
+    g_vars = [[] for _ in range(num_layers)]
+    f_vars_idxs = [[] for _ in range(num_layers)]
+    g_vars_idxs = [[] for _ in range(num_layers)]
+
+    for i, t in enumerate(variables):
+      ref = _underlying_variable(t)
 
       # Use the name to identify the layer number and function (f or g)
-      regex = LAYER_RE.match(t.name)
+      regex = LAYER_RE.match(ref.name)
       layer_no = int(regex.group(1))
       fn_name = regex.group(2)
       if fn_name == "f":
@@ -244,12 +352,15 @@ def rev_block_grad(op, *grad_ys):
     layer_scopes.reverse()
     f_vars.reverse()
     g_vars.reverse()
+    f.reverse()
+    g.reverse()
 
     for i in xrange(num_layers):
       with tf.variable_scope(layer_scopes[i], reuse=True):
-        ys, grad_ys, f_ret, g_ret = (_rev_layer_backward(
-            ys, grad_ys, f, g, f_vars[i], f_side_input, g_vars[i],
-            g_side_input))
+
+        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(ys, grad_ys, f[i], g[i],
+                                                        f_vars[i], f_side_input,
+                                                        g_vars[i], g_side_input)
 
         grad_f_vars, grad_f_side = f_ret
         grad_g_vars, grad_g_side = g_ret
@@ -262,8 +373,9 @@ def rev_block_grad(op, *grad_ys):
     acc_f_side_grads = _acc_grads(*f_side_grads)
     acc_g_side_grads = _acc_grads(*g_side_grads)
 
-    # Use the stored idxs to put gradients in the same order as hoisted_inputs.
-    hoisted_inputs_grads = [None] * len(hoisted_inputs)
+    # Use the stored idxs to put gradients in the passed-in order.
+    side_input_grads = [None] * len(side_inputs)
+    variable_grads = [None] * len(variables)
 
     # Variable gradients were collected in reverse layer order. Reverse to match
     # idxs.
@@ -272,43 +384,30 @@ def rev_block_grad(op, *grad_ys):
     for idxs, grads in zip(f_vars_idxs, f_var_grads) + zip(
         g_vars_idxs, g_var_grads):
       for i, grad in zip(idxs, grads):
-        hoisted_inputs_grads[i] = grad
+        variable_grads[i] = grad
 
     for i, grad in zip(f_side_idxs, acc_f_side_grads):
-      hoisted_inputs_grads[i] = grad
+      side_input_grads[i] = grad
    for i, grad in zip(g_side_idxs, acc_g_side_grads):
-      hoisted_inputs_grads[i] = grad
+      side_input_grads[i] = grad
 
     grad_x1, grad_x2 = grad_ys
-    return [grad_x1, grad_x2] + hoisted_inputs_grads
-
-  @function.Defun(
-      tf.float32,
-      tf.float32,
-      python_grad_func=rev_block_grad,
-      shape_func=lambda _: [x1.get_shape(), x2.get_shape()])
-  def rev_block_defun(inp1, inp2):
-    inp1.set_shape(x1.get_shape())
-    inp2.set_shape(x2.get_shape())
-    return _rev_block_forward(
-        inp1,
-        inp2,
-        f,
-        g,
-        num_layers=num_layers,
-        f_side_input=f_side_input,
-        g_side_input=g_side_input,
-        layer_scopes=layer_scopes,
-        gate_outputs=True)
+    return [grad_x1, grad_x2] + side_input_grads, variable_grads
 
-  if is_training:
-    return rev_block_defun(x1, x2)
-  else:
+  # Need a forward function with positional arguments
+  @fn_with_custom_grad(custom_grad_fn if is_training else None)
+  def forward(x1, x2, *side_inputs):
+    f_side = side_inputs[:len(f_side_input)]
+    g_side = side_inputs[len(f_side_input):]
     return _rev_block_forward(
         x1,
         x2,
         f,
        g,
         num_layers=num_layers,
-        f_side_input=f_side_input,
-        g_side_input=g_side_input)
+        f_side_input=f_side,
+        g_side_input=g_side,
+        layer_scopes=layer_scopes,
+        gate_outputs=is_training)
+
+  return forward(x1, x2, *(f_side_input + g_side_input))
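After this change, f and g may also be passed as per-layer lists (see the new isinstance checks above), and side inputs are forwarded explicitly. A hedged usage sketch; the dense layers, sizes, and names are illustrative assumptions, not part of the commit:

```python
import tensorflow as tf
from tensor2tensor.layers import rev_block

hidden_size = 64
num_layers = 2


def make_fn(name):
  # Each function creates its own variables and preserves the input shape,
  # as rev_block requires.
  def fn(x):
    return tf.layers.dense(x, hidden_size, activation=tf.nn.relu, name=name)
  return fn


f_fns = [make_fn("f_dense_%d" % i) for i in range(num_layers)]
g_fns = [make_fn("g_dense_%d" % i) for i in range(num_layers)]

x1 = tf.random_normal([8, hidden_size])
x2 = tf.random_normal([8, hidden_size])

y1, y2 = rev_block.rev_block(
    x1, x2, f_fns, g_fns, num_layers=num_layers, is_training=True)
```

With is_training=True the block routes its backward pass through custom_grad_fn via fn_with_custom_grad; with is_training=False it simply runs the forward computation.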
