A vanilla implementation of a Recurrent Neural Network (RNN) with Long Short-Term Memory (LSTM) cells, without using any ML libraries.
These networks are particularly good at learning long-term dependencies in data, and can be applied to a variety of problems, including language modelling, translation, and speech recognition.
An LSTM cell has 4 gates, based on the following formulas:
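$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g)
\end{aligned}
$$

where $f_t$, $i_t$, $o_t$, and $g_t$ are the forget, input, output, and candidate gates respectively, $x_t$ is the input at time step $t$, $h_{t-1}$ is the previous hidden state, and $\sigma$ is the logistic sigmoid. (This is the standard LSTM formulation; the parameter names in the code may differ.)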
Each gate has its own set of parameters to learn, which makes training vanilla implementations (such as this one) expensive.
These are collected into a single cell state value:
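$$
c_t = f_t \odot c_{t-1} + i_t \odot g_t
$$

where $\odot$ denotes element-wise multiplication and $c_{t-1}$ is the previous cell state.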
This is then used to produce the hidden state, just as a normal RNN cell would produce its output; LSTM cells can effectively be treated no differently from any other cell within the network:
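$$
h_t = o_t \odot \tanh(c_t)
$$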
To initialise the network, create an instance of the class by calling the constructor with the arguments:
`rnn = new LSTM_RNN(lr, in_dim, h_dim, out_dim)`
where `lr` is the learning rate, `in_dim` is the dimension of the input layer, `h_dim` is the dimension of the hidden layer, and `out_dim` is the dimension of the output layer. These should correspond to your training data.
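For example, a character-level model over a 30-symbol vocabulary might be set up as follows (the numbers are hypothetical; only the constructor shown above is taken from this project):

```js
// Hypothetical hyperparameters: learning rate 0.1, a 30-symbol vocabulary
// (so in_dim = out_dim = 30), and 100 hidden units.
var rnn = new LSTM_RNN(0.1, 30, 100, 30);
```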
The training data should be encoded as integers and given as two lists: a list of inputs and a corresponding list of targets. The RNN can then be trained by calling the function:
`rnn.train(iterations, inputs, targets, seq_len)`
where `iterations` is the number of iterations to run, `inputs` and `targets` are the training data, and `seq_len` is the length of each batch of data.
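As an end-to-end sketch, assuming only the constructor and `train` signatures documented above (the character-to-integer encoding below is illustrative and not part of this project):

```js
// Integer-encode a toy corpus: each unique character gets an index.
var text = "hello world";
var vocab = Array.from(new Set(text));
var charToIx = {};
vocab.forEach(function (c, i) { charToIx[c] = i; });

// Inputs are the encoded characters; targets are the same sequence
// shifted one step ahead, so the network learns to predict the next character.
var encoded = Array.from(text).map(function (c) { return charToIx[c]; });
var inputs = encoded.slice(0, -1);
var targets = encoded.slice(1);

// Train for 10000 iterations on batches of 5 steps each.
var rnn = new LSTM_RNN(0.1, vocab.length, 100, vocab.length);
rnn.train(10000, inputs, targets, 5);
```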
Possible future improvements:

- A sampling method to view the output of the network as it trains, using a forward pass.
- Refactor the code to use a computational-graph model.
- Use a linear sigmoid function to improve speed (see the sketch below).
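One common choice for the last item is a piecewise-linear "hard" sigmoid, which replaces the exponential with a clamp (a sketch; the exact variant intended here is an assumption):

```js
// Hard sigmoid: a piecewise-linear approximation of the logistic function.
// Avoids Math.exp(), clamping to [0, 1] outside the region -2.5 < x < 2.5.
function hardSigmoid(x) {
  return Math.max(0, Math.min(1, 0.2 * x + 0.5));
}
```

This trades a small amount of accuracy in the gate activations for cheaper evaluation.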