Move autoregressive_sample and related test to decoding module.
Correct Transformer model and test fast decoding from Transformer.
Unfortunately, this changes the Transformer checkpoint format (as positional embeddings are no longer shared).

PiperOrigin-RevId: 321124488
Lukasz Kaiser authored and copybara-github committed Jul 14, 2020
1 parent a3e159a commit fe0fa78
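
For context, here is a minimal sketch (not part of this commit) of how fast decoding from the Transformer might be exercised after the move. It assumes `trax.supervised.decoding.autoregressive_sample` with the keyword arguments shown (`inputs`, `eos_id`, `max_length`, `temperature`), and the model hyperparameters and checkpoint path are placeholders:

```python
import numpy as np
import trax
from trax.supervised import decoding

# Build the Transformer in 'predict' mode so decoding runs stepwise with
# cached activations (fast decoding), which this commit fixes and tests.
model = trax.models.Transformer(
    input_vocab_size=33300, d_model=512, d_ff=2048,
    n_encoder_layers=6, n_decoder_layers=6, n_heads=8, mode='predict')

# Checkpoints produced before this commit use the old (shared positional
# embedding) format; 'model.pkl.gz' stands in for a newly trained checkpoint.
model.init_from_file('model.pkl.gz', weights_only=True)

inputs = np.asarray([[10, 11, 12, 1]])  # one tokenized, EOS-terminated source
tokens = decoding.autoregressive_sample(
    model, inputs=inputs, eos_id=1, max_length=20, temperature=0.0)
```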
Showing 11 changed files with 318 additions and 215 deletions.
31 changes: 29 additions & 2 deletions trax/layers/base.py
@@ -344,25 +344,52 @@ def weights(self):
- an empty tuple
- a tensor (ndarray)
- a nested structure of tuples and tensors
If the layer has sublayers, the weights by convention will usually be
a tuple of length `len(sublayers)` containing the weights of sublayers.
"""
return self._weights

@weights.setter
def weights(self, weights):
"""Sets the weights of this layer and its sublayers."""
if isinstance(weights, dict) and weights == GET_WEIGHTS_FROM_CACHE:
return
self._weights = weights
# Set sublayer weights.
if self.sublayers:
n_layers = len(self.sublayers)
if len(weights) != n_layers:
raise ValueError(
f'Number of weight elements ({len(weights)}) does not equal the '
f'number of sublayers ({n_layers}) in: {str(self)}.')
for sublayer, sublayer_weights in zip(self.sublayers, weights):
sublayer.weights = sublayer_weights

@property
def state(self):
"""Returns a tuple containing this layer's state; may be empty."""
"""Returns a tuple containing this layer's state; may be empty.
If the layer has sublayers, the state by convention will usually be
a tuple of length `len(sublayers)` containing sublayer states.
"""
return self._state

@state.setter
def state(self, state):
if isinstance(state, dict) and state != GET_STATE_FROM_CACHE:
"""Sets the state of this layer and its sublayers."""
if isinstance(state, dict) and state == GET_STATE_FROM_CACHE:
return
self._state = state
# Set sublayer states.
if self.sublayers:
n_layers = len(self.sublayers)
if len(state) != n_layers:
raise ValueError(
f'Number of state elements ({len(state)}) does not equal the '
f'number of sublayers ({n_layers}) in: {str(self)}.')
for sublayer, sublayer_state in zip(self.sublayers, state):
sublayer.state = sublayer_state

def weights_and_state_signature(self, input_signature):
"""Return a pair containing the signatures of weights and state."""
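As a side note, here is a minimal sketch (not part of the diff) of the convention the new setters in `base.py` enforce: a layer with sublayers holds its weights and state as tuples with one entry per sublayer, and assignment propagates recursively. The layer and input shape below are just examples:

```python
import numpy as np
from trax import layers as tl
from trax import shapes

model = tl.Serial(tl.Dense(4), tl.Relu())
model.init(shapes.ShapeDtype((1, 8), np.float32))

ws = model.weights       # tuple with one entry per sublayer (here, 2)
model.weights = ws       # also sets the Dense and Relu sublayer weights
assert model.sublayers[0].weights is ws[0]

# A tuple of the wrong length now raises ValueError instead of silently
# misassigning weights:
#   model.weights = ws[:1]  -> 'Number of weight elements (1) does not ...'
```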
98 changes: 30 additions & 68 deletions trax/layers/combinators.py
@@ -111,36 +111,6 @@ def init_weights_and_state(self, input_signature):
self._weights = weights
# pylint: enable=protected-access

@base.Layer.weights.setter
def weights(self, weights):
"""Recursively sets weights on this layer and all sublayers."""
if isinstance(weights, dict) and weights == base.GET_WEIGHTS_FROM_CACHE:
return
self._weights = weights
n_layers = self._n_layers
if len(weights) != n_layers:
raise ValueError(
f'Number of weight elements ({len(weights)}) does not equal '
f'number of sublayers ({n_layers}).')
for layer, sublayer_weights in zip(self.sublayers, weights):
if sublayer_weights is not base.GET_WEIGHTS_FROM_CACHE:
layer.weights = sublayer_weights

@base.Layer.state.setter
def state(self, state):
"""Recursively sets non-param state on this layer and all sublayers."""
if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE:
return
self._state = state
n_layers = self._n_layers
if n_layers != 1 and len(state) != n_layers:
raise ValueError(
f'Number of state elements ({len(state)}) does not equal '
f'number of sublayers ({n_layers}).')
for layer, sublayer_state in zip(self.sublayers, state):
if sublayer_state is not base.GET_STATE_FROM_CACHE:
layer.state = sublayer_state

def _n_inputs_n_outputs(self, layers):
del self
running_max = 0
@@ -256,26 +226,6 @@ def init_weights_and_state(self, input_signature):
self._state = state
self._weights = weights

@base.Layer.weights.setter
def weights(self, weights):
"""Recursively sets weights on this layer and all sublayers."""
if isinstance(weights, dict) and weights == base.GET_WEIGHTS_FROM_CACHE:
return
self._weights = weights
assert len(weights) == self._n_layers
for layer, sublayer_weights in zip(self.sublayers, weights):
layer.weights = sublayer_weights

@base.Layer.state.setter
def state(self, state):
"""Recursively sets non-param state on this layer and all sublayers."""
if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE:
return
self._state = state
assert len(state) == self._n_layers
for layer, sublayer_state in zip(self.sublayers, state):
layer.state = sublayer_state

def _validate(self, layers):
if not layers or len(layers) < 2:
raise ValueError(
@@ -381,14 +331,16 @@ def __init__(self, layer, axis=0, n_carry=1, remat=False):
self._n_carry = n_carry
self._axis = axis
self._remat = remat
self._weights = (base.EMPTY_WEIGHTS,)
self._state = (base.EMPTY_STATE,)

@property
def sublayer(self):
"""Returns the unique sublayer managed by this layer."""
return self._sublayers[0]

def forward(self, inputs):
weights = self.weights
weights = self.weights[0]
if isinstance(inputs, list):
inputs = tuple(inputs) # so that inputs structure matches outputs
n_carry = self._n_carry
@@ -404,13 +356,13 @@ def scannable_fn(x, carry_and_state): # pylint: disable=invalid-name

if n_carry > 0:
xs = inputs[:-n_carry] # Split input stack into inputs and carry.
init = (inputs[-n_carry:], self.state)
init = (inputs[-n_carry:], self.state[0])
else:
xs, init = inputs, ([], self.state)
xs, init = inputs, ([], self.state[0])
ys, (carry, new_state) = fastmath.scan(scannable_fn, xs, init,
axis=self._axis, remat=self._remat)
res = ys + carry if n_carry > 0 else ys
self.state = new_state
self.state = (new_state,)
return res # Put outputs and carry back on stack.

def init_weights_and_state(self, input_signature):
@@ -424,17 +376,17 @@ def init_weights_and_state(self, input_signature):
layer_sig = ShapeDtype(_shape_without_axis(input_signature, self._axis),
input_signature.dtype)
weights, state = self.sublayer.init(layer_sig)
self._state = state
self._weights = weights
self._state = (state,)
self._weights = (weights,)
else:
xs = input_signature[:-n_carry]
init = input_signature[-n_carry:]
xs_slices = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype)
for x in xs]
layer_signature = tuple(xs_slices + list(init))
weights, state = self.sublayer.init(layer_signature, use_cache=True)
self._state = state
self._weights = weights
self._state = (state,)
self._weights = (weights,)


def Branch(*layers, name='Branch'):
@@ -652,8 +604,21 @@ def sublayer(self):
"""Returns the unique sublayer managed by this layer."""
return self._sublayers[0]

@base.Layer.state.setter
def state(self, state):
"""Recursively sets state on this layer and all sublayers."""
if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE:
return
self._state = state
self.sublayer.state = state[1]

def init_weights_and_state(self, input_signature):
weights, layer_state = self.sublayer.init(input_signature, use_cache=True)
self.state = ((), layer_state)
self._weights = (weights,)

def forward(self, inputs):
state, weights = self.state, self.weights
state, weights = self.state, self.weights[0]
if state[0] is (): # pylint: disable=literal-comparison
res, layer_state = self.sublayer.pure_fn(
inputs, weights, state[1], self.rng)
@@ -662,11 +627,6 @@ def forward(self, inputs):
else:
return state[0]

def init_weights_and_state(self, input_signature):
weights, layer_state = self.sublayer.init(input_signature, use_cache=True)
self.state = ((), layer_state)
self._weights = weights


class BatchLeadingAxes(base.Layer):
"""Applies a layer after flattening all but n_last_axes_to_keep to batch.
@@ -684,6 +644,8 @@ def __init__(self, layer, n_last_axes_to_keep=1):
super(BatchLeadingAxes, self).__init__(n_in=layer.n_in, n_out=layer.n_out)
self._sublayers = [layer]
self._n_last_axes_to_keep = n_last_axes_to_keep
self._weights = (base.EMPTY_WEIGHTS,)
self._state = (base.EMPTY_STATE,)

@property
def sublayer(self):
@@ -695,14 +657,14 @@ def forward(self, inputs):
batched_shape = [-1] + list(inputs.shape[-self._n_last_axes_to_keep:])
inputs = jnp.reshape(inputs, batched_shape)
res, layer_state = self.sublayer.pure_fn(
inputs, self.weights, self.state, self.rng)
self.state = layer_state
inputs, self.weights[0], self.state[0], self.rng)
self.state = (layer_state,)
return jnp.reshape(res, batched_axes_shape + list(res.shape[1:]))

def init_weights_and_state(self, input_signature):
weights, layer_state = self.sublayer.init(input_signature, use_cache=True)
self.state = layer_state
self._weights = weights
self.state = (layer_state,)
self._weights = (weights,)


# All module-private helper functions are below.
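Relatedly, a small sketch (not part of the diff) of how these `combinators.py` changes look from the outside: wrapper layers such as `BatchLeadingAxes` now store their single sublayer's weights and state as length-1 tuples, so the generic one-entry-per-sublayer setters in `base.py` apply. Names and shapes are illustrative, and `BatchLeadingAxes` is assumed to be reachable via `trax.layers`:

```python
import numpy as np
from trax import layers as tl
from trax import shapes

wrapper = tl.BatchLeadingAxes(tl.Dense(3), n_last_axes_to_keep=1)
wrapper.init(shapes.ShapeDtype((2, 5, 8), np.float32))

assert len(wrapper.weights) == 1   # (dense_weights,)
assert len(wrapper.state) == 1     # (dense_state,)
dense_weights = wrapper.weights[0]
```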
5 changes: 4 additions & 1 deletion trax/layers/rnn.py
@@ -66,7 +66,10 @@ def forward(self, inputs):

def init_weights_and_state(self, input_signature):
# LSTM state last dimension must be twice n_units.
assert input_signature[1].shape[-1] == 2 * self._n_units
if input_signature[1].shape[-1] != 2 * self._n_units:
raise ValueError(
f'Last dimension of state (shape: {str(input_signature[1].shape)}) '
f'must be equal to 2*n_units ({2 * self._n_units})')
# The dense layer input is the input and half of the lstm state.
input_shape = input_signature[0].shape[-1] + self._n_units
rng1, rng2 = fastmath.random.split(self.rng, 2)
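For clarity, a small sketch (not part of the diff) of the shape contract this new error message spells out: the LSTM state packs two `n_units`-sized vectors, so its last dimension must equal `2 * n_units`. The shapes below are illustrative:

```python
import numpy as np
from trax import layers as tl
from trax import shapes

cell = tl.LSTMCell(n_units=16)
signature = (shapes.ShapeDtype((1, 8), np.float32),    # input activations
             shapes.ShapeDtype((1, 32), np.float32))   # state: 2 * n_units
cell.init(signature)

# A state whose last dimension is not 2 * n_units now raises ValueError
# instead of failing an assert.
```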
18 changes: 17 additions & 1 deletion trax/models/reformer/reformer.py
@@ -106,6 +106,22 @@ def __init__(self, *residual_layers, attention_layer=None):
running_total -= layer.n_out
self._n_in = self._n_out = running_max + 1

# TODO(lukaszkaiser): These setters should not be needed; investigate why they
# cause the e2e test to fail, then remove them.
@tl.Layer.weights.setter
def weights(self, weights):
"""Sets the weights of this layer and its sublayers."""
if isinstance(weights, dict) and weights == tl.GET_WEIGHTS_FROM_CACHE:
return
self._weights = weights

@tl.Layer.state.setter
def state(self, state):
"""Sets the state of this layer and its sublayers."""
if isinstance(state, dict) and state == tl.GET_STATE_FROM_CACHE:
return
self._state = state

def forward(self, xs):
rngs = _split_rngs(self.rng, len(self.sublayers))
accumulator, *context = xs
@@ -120,7 +136,7 @@ def forward(self, xs):

output = accumulator + residual
stack = (output,) + context
self.state = new_state
self.state = tuple(new_state)
return stack

def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
3 changes: 1 addition & 2 deletions trax/models/research/skipping_transformer.py
@@ -57,8 +57,7 @@ def state(self, state):
f'Number of state elements ({len(state[1])}) does not equal '
f'number of sublayers ({n_layers}).')
for layer, sublayer_state in zip(self.sublayers, state[1]):
if sublayer_state is not tl.GET_STATE_FROM_CACHE:
layer.state = sublayer_state
layer.state = sublayer_state

def forward(self, xs):
self._validate_forward_inputs(xs)
24 changes: 17 additions & 7 deletions trax/models/transformer.py
@@ -221,16 +221,26 @@ def Transformer(input_vocab_size,
A Transformer model as a layer that maps from a source, target pair to
activations over a vocab set.
"""
def PositionalEncoder(vocab_size): # tokens --> vectors
def Embedder(vocab_size): # tokens --> vectors
return [
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len),
]

in_encoder = PositionalEncoder(input_vocab_size)
out_encoder = (in_encoder if output_vocab_size is None
else PositionalEncoder(output_vocab_size))
in_embedder = Embedder(input_vocab_size)
out_embedder = (in_embedder if output_vocab_size is None
else Embedder(output_vocab_size))

# Positional encodings are not shared between the encoder and decoder.
# Since the encoder doesn't run stepwise, we do not use predict mode there.
encoder_mode = 'eval' if mode == 'predict' else mode
in_encoder = in_embedder + [
tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)
]
out_encoder = out_embedder + [
tl.PositionalEncoding(max_len=max_len, mode=mode)
]

if output_vocab_size is None:
output_vocab_size = input_vocab_size

@@ -264,7 +274,7 @@ def PositionalEncoder(vocab_size): # tokens --> vectors

# Decode.
tl.Select([2, 1, 0]), # tok_d masks vec_e .....
tl.ShiftRight(), # tok_d ..... ..... .....
tl.ShiftRight(mode=mode), # tok_d ..... ..... .....
out_encoder, # vec_d ..... ..... .....
tl.Branch(
[], tl.EncoderDecoderMask()), # vec_d masks ..... .....
@@ -361,7 +371,7 @@ def PositionalEncoder(vocab_size): # tokens --> vectors

# Decode.
tl.Select([3, 1, 0, 2]), # tok_d vec_e mask_e tok_e tok_d
tl.ShiftRight(), # stok_d vec_e mask_e tok_e tok_d
tl.ShiftRight(mode=mode), # stok_d vec_e mask_e tok_e tok_d
tl.Branch(
[],
_MaskOfRightShiftedArray()
