Skip to content

Commit

Permalink
Move LSTM cells into separate set of layers than ordinary hidden neur…
Browse files Browse the repository at this point in the history
…on layers
  • Loading branch information
boingoing committed Apr 3, 2023
1 parent d435408 commit 68f9fca
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
33 changes: 24 additions & 9 deletions src/RecurrentNeuralNetwork.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,32 @@ size_t RecurrentNeuralNetwork::GetCellMemorySize() const {
return cell_memory_size_;
}

void RecurrentNeuralNetwork::SetHiddenLayerCount(size_t layer_count) {
assert(!IsTopologyConstructed());
// TODO(boingoing): Support resetting the count of hidden layers.
assert(GetHiddenLayerCount() == 0);

for (size_t i = 0; i < layer_count; i++) {
AddHiddenLayer(0);
/**
 * Append a hidden layer made of |cell_count| LSTM cells to the network.
 *
 * One CellLayer entry is appended to |layers_| recording the index in |cells_| of the layer's first cell and the count of cells in the layer. For each cell, one entry is appended to |cells_|, |cell_states_count_| grows by the cell's memory size and AddHiddenNeurons() is called with 5 neurons per unit of memory.
 *
 * @param cell_count Count of cells to add to the new layer.
 * @param cell_memory_sizes Per-cell memory sizes. Cell i uses element i, unless the vector has no element i or that element is 0, in which case GetCellMemorySize() is used instead. Elements at index >= |cell_count| are ignored.
 */
void RecurrentNeuralNetwork::AddHiddenLayer(size_t cell_count, const std::vector<size_t>& cell_memory_sizes) {
  // The cells of this layer are appended to |cells_| contiguously, beginning at this index.
  const size_t first_cell_index = cells_.size();

  // CellLayer is a plain aggregate with no constructor. List-initialization
  // works on every language version, whereas emplace_back(a, b) on an
  // aggregate requires C++20 parenthesized aggregate initialization.
  layers_.push_back({first_cell_index, cell_count});

  for (size_t i = 0; i < cell_count; i++) {
    // Evaluated on every iteration, before this cell's neurons are added
    // (presumably AddHiddenNeurons() below grows GetHiddenNeuronCount()).
    const size_t cell_hidden_neuron_start_index = GetHiddenNeuronStartIndex() + GetHiddenNeuronCount();

    // A missing or zero entry selects the network-wide default memory size.
    const bool has_explicit_size = i < cell_memory_sizes.size() && cell_memory_sizes[i] != 0;
    const size_t cell_memory_size = has_explicit_size ? cell_memory_sizes[i] : GetCellMemorySize();

    // This cell's states start at the running total of all previous cells' states.
    const size_t cell_states_index = cell_states_count_;

    // Forget gate, input gate, output gate, candidate cell state layer, hidden state layer.
    const size_t neurons_per_gate = cell_memory_size;
    const size_t neurons_per_cell = neurons_per_gate * 5;

    // Same four values, in the same order, as were previously passed to emplace_back.
    cells_.push_back({cell_hidden_neuron_start_index, neurons_per_cell, cell_states_index, cell_memory_size});

    cell_states_count_ += cell_memory_size;
    AddHiddenNeurons(neurons_per_cell);
  }
}

void RecurrentNeuralNetwork::Allocate() {
assert(!IsTopologyConstructed());

cell_states_.resize(cell_states_count_);

NeuralNetwork::Allocate();

assert(GetHiddenLayerCount() > 0);

// Allocate the neurons and memory cells
Expand Down Expand Up @@ -269,7 +283,7 @@ void RecurrentNeuralNetwork::ConnectFully() {
ConnectBiasNeuron(bias_neuron_index, GetOutputNeuronStartIndex(), GetOutputNeuronCount());
}

void RecurrentNeuralNetwork::UpdateCellState(LongShortTermMemoryCell& cell) {
void RecurrentNeuralNetwork::UpdateCellState(const LongShortTermMemoryCell& cell) {
const size_t forget_gate_neuron_start_index = cell.neuron_start_index;
const size_t input_gate_neuron_start_index = forget_gate_neuron_start_index + cell.cell_state_count;
const size_t candidate_cell_state_gate_neuron_start_index = input_gate_neuron_start_index + cell.cell_state_count;
Expand Down Expand Up @@ -305,10 +319,11 @@ void RecurrentNeuralNetwork::RunForward(const std::vector<double>& input) {

// Feed each input into the corresponding input neuron.
for (size_t i = 0; i < GetInputNeuronCount(); i++) {
auto& neuron = GetNeuron(GetInputNeuronStartIndex() + i);
auto& neuron = GetInputNeuron(i);
neuron.value = input[i];
}

// Update cell states.
for (auto& cell : cells_) {
UpdateCellState(cell);
}
Expand Down
25 changes: 20 additions & 5 deletions src/RecurrentNeuralNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
namespace panann {

/**
* A recurrent artificial neural network made out of long short term memory cells.<br/>
*
* This network doesn't contain ordinary hidden neurons organized into layers. Instead, each layer contains a set of recurrent cells which are each made of a number of hidden units grouped into gates.
*/
class RecurrentNeuralNetwork : public NeuralNetwork {
protected:
Expand All @@ -32,12 +34,16 @@ class RecurrentNeuralNetwork : public NeuralNetwork {
size_t cell_state_start_index;

/**
* Count of cell states belonging to the cell.<br/>
* Note: All cells currently have the same count of cell states and that count equals |cell_memory_size_|.
* Count of cell states belonging to the cell.
*/
size_t cell_state_count;
};

/**
* Describes one hidden layer as a contiguous run of cells inside |cells_|.
*/
struct CellLayer {
/**
* Index into |cells_| of the first cell belonging to the layer.
*/
size_t cell_start_index;
/**
* Count of cells belonging to the layer.
*/
size_t cell_count;
};

public:
RecurrentNeuralNetwork() = default;
RecurrentNeuralNetwork(const RecurrentNeuralNetwork&) = delete;
Expand All @@ -52,26 +58,35 @@ class RecurrentNeuralNetwork : public NeuralNetwork {
void SetCellMemorySize(size_t memory_size);
size_t GetCellMemorySize() const;

void SetHiddenLayerCount(size_t layer_count);

void RunForward(const std::vector<double>& input) override;

/**
* Get a writable vector of memory state for all cells in the network.
*/
std::vector<double>& GetCellStates();

void AddHiddenLayer(size_t neuron_count) = delete;

/**
* Add a hidden layer of LSTM cells.<br/>
* Each cell may have a different cell memory size passed via |cell_memory_sizes|. If the vector doesn't contain an element for a cell or if that element is 0, the cell memory size for that cell will be the default returned via GetCellMemorySize().
* @see GetCellMemorySize()
*/
void AddHiddenLayer(size_t cell_count, const std::vector<size_t>& cell_memory_sizes);

protected:
void Allocate() override;
void ConnectFully() override;

void UpdateCellState(LongShortTermMemoryCell& cell);
void UpdateCellState(const LongShortTermMemoryCell& cell);

private:
static constexpr size_t DefaultCellMemorySize = 200;

std::vector<CellLayer> layers_;
std::vector<LongShortTermMemoryCell> cells_;
std::vector<double> cell_states_;
size_t cell_states_count_ = 0;
size_t cell_memory_size_ = DefaultCellMemorySize;
};

Expand Down

0 comments on commit 68f9fca

Please sign in to comment.