:label:sec_gru
As RNNs and particularly the LSTM architecture (:numref:sec_lstm)
rapidly gained popularity during the 2010s,
a number of researchers began to experiment
with simplified architectures in hopes
of retaining the key idea of incorporating
an internal state and multiplicative gating mechanisms
but with the aim of speeding up computation.
The gated recurrent unit (GRU) :cite:Cho.Van-Merrienboer.Bahdanau.ea.2014
offered a streamlined version of the LSTM memory cell
that often achieves comparable performance
but with the advantage of being faster
to compute :cite:Chung.Gulcehre.Cho.ea.2014.
%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
%%tab mxnet
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
npx.set_np()
%%tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
%%tab tensorflow
from d2l import tensorflow as d2l
import tensorflow as tf
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
Here, the LSTM's three gates are replaced by two:
the reset gate and the update gate.
As with LSTMs, these gates are given sigmoid activations,
forcing their values to lie in the interval $(0, 1)$. Intuitively, the reset gate controls how much of the previous state we might still want to remember, while the update gate controls how much of the new state is just a copy of the old one. :numref:fig_gru_1
illustrates the inputs for both
the reset and update gates in a GRU,
given the input of the current time step
and the hidden state of the previous time step.
The outputs of the two gates are given
by two fully connected layers
with a sigmoid activation function.
Mathematically, for a given time step $t$, suppose that the input is a minibatch $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples $n$, number of inputs $d$) and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units $h$). Then the reset gate $\mathbf{R}_t \in \mathbb{R}^{n \times h}$ and the update gate $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$ are computed as follows:
$$
\begin{aligned}
\mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),
\end{aligned}
$$
where $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$
and $\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$
are weight parameters and $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$ are bias parameters.
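To make the shapes concrete, here is a minimal sketch of these two gate equations in plain PyTorch; the sizes n, d, and h below are made up for illustration, and the from-scratch implementation later in this section uses the d2l helpers instead.
import torch

n, d, h = 2, 8, 4  # hypothetical batch size, input size, number of hidden units
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

# Reset and update gates; broadcasting adds the (h,)-shaped biases to each row
R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)
print(R_t.shape, Z_t.shape)                 # both (n, h)
print(float(R_t.min()), float(R_t.max()))   # all entries lie in (0, 1)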
Candidate Hidden State
Next, we integrate the reset gate $\mathbf{R}_t$ with the regular updating mechanism in :eqref:rnn_h_with_state, leading to the following candidate hidden state $\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$:
$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),$$
:eqlabel:gru_tilde_H
where $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$
are weight parameters, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$ is the bias, and the symbol $\odot$ denotes the Hadamard (elementwise) product.
The result is a candidate, since we still need
to incorporate the action of the update gate.
Comparing with :eqref:rnn_h_with_state,
now the influence of the previous states
can be reduced with the
elementwise multiplication of
$\mathbf{R}_t$ and $\mathbf{H}_{t-1}$
in :eqref:gru_tilde_H.
Whenever the entries in the reset gate $\mathbf{R}_t$ are close to 1, we recover a vanilla RNN such as that in :eqref:rnn_h_with_state.
For all entries of the reset gate $\mathbf{R}_t$ that are close to 0, the candidate hidden state is the result of an MLP with $\mathbf{X}_t$ as input; any pre-existing hidden state is thus reset to defaults.
:numref:fig_gru_2
illustrates the computational flow after applying the reset gate.
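These two extreme cases are easy to check numerically. The following sketch (again plain PyTorch with made-up sizes) evaluates :eqref:gru_tilde_H once with an all-ones reset gate and once with an all-zeros reset gate.
import torch

n, d, h = 2, 8, 4  # hypothetical sizes, as before
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

def candidate(R_t):
    return torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

# Reset gate of all ones: the vanilla RNN update from :eqref:rnn_h_with_state
H_tilde_keep = candidate(torch.ones(n, h))
# Reset gate of all zeros: an MLP of X_t alone; the previous state is "reset"
H_tilde_reset = candidate(torch.zeros(n, h))
assert torch.allclose(H_tilde_reset, torch.tanh(X_t @ W_xh + b_h))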
Hidden State
Finally, we need to incorporate the effect of the update gate $\mathbf{Z}_t$. This determines the extent to which the new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$ matches the old state $\mathbf{H}_{t-1}$, as opposed to the new candidate state $\tilde{\mathbf{H}}_t$. The update gate achieves this simply by taking an elementwise convex combination of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$, giving the final update equation for the GRU:
$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$
Whenever the update gate $\mathbf{Z}_t$ is close to 1, we simply retain the old state. In this case the information from $\mathbf{X}_t$ is ignored, effectively skipping time step $t$ in the dependency chain. By contrast, whenever $\mathbf{Z}_t$ is close to 0, the new hidden state $\mathbf{H}_t$ approaches the candidate hidden state $\tilde{\mathbf{H}}_t$. :numref:fig_gru_3
illustrates the computational flow after the update gate is in action.
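The extreme cases of the update gate can be verified in the same spirit; a minimal sketch with made-up tensors:
import torch

n, h = 2, 4  # hypothetical sizes
H_prev, H_tilde = torch.randn(n, h), torch.randn(n, h)

def update(Z_t):
    return Z_t * H_prev + (1 - Z_t) * H_tilde

# Update gate of all ones: the old state is copied and the input is ignored
assert torch.allclose(update(torch.ones(n, h)), H_prev)
# Update gate of all zeros: the new state is exactly the candidate state
assert torch.allclose(update(torch.zeros(n, h)), H_tilde)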
In summary, GRUs have the following two distinguishing features:
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
To gain a better understanding of the GRU model, let's implement it from scratch.
The first step is to initialize the model parameters.
We draw the weights from a Gaussian distribution
with standard deviation sigma
and set the bias to 0.
The hyperparameter num_hiddens
defines the number of hidden units.
We instantiate all weights and biases relating to the update gate,
the reset gate, and the candidate hidden state.
%%tab pytorch, mxnet, tensorflow
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            init_weight = lambda *shape: d2l.randn(*shape) * sigma
            triple = lambda: (init_weight(num_inputs, num_hiddens),
                              init_weight(num_hiddens, num_hiddens),
                              d2l.zeros(num_hiddens))
        if tab.selected('pytorch'):
            init_weight = lambda *shape: nn.Parameter(d2l.randn(*shape) * sigma)
            triple = lambda: (init_weight(num_inputs, num_hiddens),
                              init_weight(num_hiddens, num_hiddens),
                              nn.Parameter(d2l.zeros(num_hiddens)))
        if tab.selected('tensorflow'):
            init_weight = lambda *shape: tf.Variable(d2l.normal(shape) * sigma)
            triple = lambda: (init_weight(num_inputs, num_hiddens),
                              init_weight(num_hiddens, num_hiddens),
                              tf.Variable(d2l.zeros(num_hiddens)))

        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
%%tab jax
class GRUScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        init_weight = lambda name, shape: self.param(
            name, nn.initializers.normal(self.sigma), shape)
        triple = lambda name: (
            init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
            init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
            self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens)))

        self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state
Now we are ready to [define the GRU forward computation]. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.
%%tab pytorch, mxnet, tensorflow
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        if tab.selected('mxnet'):
            H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                          ctx=inputs.ctx)
        if tab.selected('pytorch'):
            H = d2l.zeros((inputs.shape[1], self.num_hiddens),
                          device=inputs.device)
        if tab.selected('tensorflow'):
            H = d2l.zeros((inputs.shape[1], self.num_hiddens))
    outputs = []
    for X in inputs:
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H
%%tab jax
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation
    def scan_fn(H, X):
        Z = d2l.sigmoid(d2l.matmul(X, self.W_xz) +
                        d2l.matmul(H, self.W_hz) + self.b_z)
        R = d2l.sigmoid(d2l.matmul(X, self.W_xr) +
                        d2l.matmul(H, self.W_hr) + self.b_r)
        H_tilde = d2l.tanh(d2l.matmul(X, self.W_xh) +
                           d2l.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        return H, H  # return carry, y

    if H is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
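Before training, it can be helpful to sanity-check the tensor shapes. The following is a minimal sketch, assuming the PyTorch tab (where inputs are laid out as (num_steps, batch_size, num_inputs)) and made-up sizes:
X = torch.randn(9, 2, 27)  # 9 time steps, batch of 2, 27-dimensional inputs
scratch_gru = GRUScratch(num_inputs=27, num_hiddens=32)
outputs, H = scratch_gru(X)
# 9 outputs, each of shape (2, 32); final state H of shape (2, 32)
print(len(outputs), outputs[0].shape, H.shape)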
[Training] a language model on The Time Machine dataset
works in exactly the same manner as in :numref:sec_rnn-scratch.
%%tab all
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
if tab.selected('mxnet', 'pytorch', 'jax'):
    gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
    model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
    trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
if tab.selected('tensorflow'):
    with d2l.try_gpu():
        gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
        model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
        trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)
In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration detail that we made explicit above.
%%tab pytorch, mxnet, tensorflow
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.rnn = rnn.GRU(num_hiddens)
        if tab.selected('pytorch'):
            self.rnn = nn.GRU(num_inputs, num_hiddens)
        if tab.selected('tensorflow'):
            self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
                                           return_state=True)
%%tab jax
class GRU(d2l.RNN):
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        if H is None:
            batch_size = inputs.shape[1]
            H = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0),
                                            (batch_size,), self.num_hiddens)

        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        H, outputs = GRU()(H, inputs)
        return outputs, H
The code is significantly faster in training as it uses compiled operators rather than Python.
%%tab all
if tab.selected('mxnet', 'pytorch', 'tensorflow'):
    gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
if tab.selected('jax'):
    gru = GRU(num_hiddens=32)
if tab.selected('mxnet', 'pytorch', 'jax'):
    model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
if tab.selected('tensorflow'):
    with d2l.try_gpu():
        model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
After training, we print out the perplexity on the training set and the predicted sequence following the provided prefix.
%%tab mxnet, pytorch
model.predict('it has', 20, data.vocab, d2l.try_gpu())
%%tab tensorflow
model.predict('it has', 20, data.vocab)
%%tab jax
model.predict('it has', 20, data.vocab, trainer.state.params)
Compared with LSTMs, GRUs achieve similar performance but tend to be lighter computationally. Generally, compared with simple RNNs, gated RNNs like LSTMs and GRUs can better capture dependencies for sequences with large time step distances. GRUs contain basic RNNs as their extreme case whenever the reset gate is switched on. They can also skip subsequences by turning on the update gate.
- Assume that we only want to use the input at time step $t'$ to predict the output at time step $t > t'$. What are the best values for the reset and update gates for each time step?
- Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
- Compare runtime, perplexity, and the output strings for rnn.RNN and rnn.GRU implementations with each other.
- What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab: