#!/usr/bin/env python
import torch
from torch import nn
import torch.nn.functional as F


class NTM(nn.Module):
    """A Neural Turing Machine."""
    def __init__(self, num_inputs, num_outputs, controller, memory, heads):
        """Initialize the NTM.

        :param num_inputs: External input size.
        :param num_outputs: External output size.
        :param controller: :class:`LSTMController`
        :param memory: :class:`NTMMemory`
        :param heads: list of :class:`NTMReadHead` or :class:`NTMWriteHead`

        Note: This design allows the flexibility of using any number of read
        and write heads independently; the order in which the heads are called
        is controlled by the user (their order in the list).
        """
        super(NTM, self).__init__()

        # Save arguments
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.controller = controller
        self.memory = memory
        self.heads = heads

        self.N, self.M = memory.size()
        _, self.controller_size = controller.size()

        # Initialize the previous read values to small random biases
        self.num_read_heads = 0
        self.init_r = []
        for head in heads:
            if head.is_read_head():
                init_r_bias = torch.randn(1, self.M) * 0.01
                self.register_buffer("read{}_bias".format(self.num_read_heads), init_r_bias.data)
                self.init_r += [init_r_bias]
                self.num_read_heads += 1

        assert self.num_read_heads > 0, "heads list must contain at least a single read head"

        # Initialize a fully connected layer to produce the actual output:
        #   [controller_output; previous_reads] -> output
        self.fc = nn.Linear(self.controller_size + self.num_read_heads * self.M, num_outputs)
        self.reset_parameters()
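
    # State layout, as produced by create_new_state() and threaded through
    # forward() below (shapes inferred from this file):
    #   prev_reads:       list of num_read_heads tensors, each (batch_size, M)
    #   controller_state: whatever controller.create_new_state() returns
    #   heads_states:     one state per head, in the same order as `heads`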

    def create_new_state(self, batch_size):
        init_r = [r.clone().repeat(batch_size, 1) for r in self.init_r]
        controller_state = self.controller.create_new_state(batch_size)
        heads_state = [head.create_new_state(batch_size) for head in self.heads]

        return init_r, controller_state, heads_state

    def reset_parameters(self):
        # Initialize the linear output layer
        nn.init.xavier_uniform_(self.fc.weight, gain=1)
        nn.init.normal_(self.fc.bias, std=0.01)

    def forward(self, x, prev_state):
        """NTM forward function.

        :param x: input vector (batch_size x num_inputs)
        :param prev_state: The previous state of the NTM
        """
        # Unpack the previous state
        prev_reads, prev_controller_state, prev_heads_states = prev_state

        # Use the controller to get an embedding of the input and previous reads
        inp = torch.cat([x] + prev_reads, dim=1)
        controller_outp, controller_state = self.controller(inp, prev_controller_state)

        # Read/write from the list of heads
        reads = []
        heads_states = []
        for head, prev_head_state in zip(self.heads, prev_heads_states):
            if head.is_read_head():
                r, head_state = head(controller_outp, prev_head_state)
                reads += [r]
            else:
                head_state = head(controller_outp, prev_head_state)
            heads_states += [head_state]

        # Generate output
        inp2 = torch.cat([controller_outp] + reads, dim=1)
        o = torch.sigmoid(self.fc(inp2))  # F.sigmoid is deprecated in favor of torch.sigmoid

        # Pack the current state
        state = (reads, controller_state, heads_states)

        return o, state
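

# ---------------------------------------------------------------------------
# Minimal usage sketch.  The real LSTMController, NTMMemory, NTMReadHead and
# NTMWriteHead classes live in the companion modules of this repository; the
# tiny stand-ins below are hypothetical stubs that only mimic the interface
# this file relies on (size(), create_new_state(), is_read_head(), forward()),
# so the state-threading pattern of NTM can be demonstrated end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubController(nn.Module):
        """Hypothetical stand-in for LSTMController (a stateless linear map)."""
        def __init__(self, num_inputs, num_outputs):
            super(_StubController, self).__init__()
            self.num_inputs, self.num_outputs = num_inputs, num_outputs
            self.fc = nn.Linear(num_inputs, num_outputs)

        def size(self):
            return self.num_inputs, self.num_outputs

        def create_new_state(self, batch_size):
            return None  # a stateless stub carries no recurrent state

        def forward(self, x, prev_state):
            return torch.tanh(self.fc(x)), prev_state

    class _StubMemory(object):
        """Hypothetical stand-in for NTMMemory: a fixed N x M content matrix."""
        def __init__(self, N, M):
            self.N, self.M = N, M
            self.mem = torch.zeros(N, M)

        def size(self):
            return self.N, self.M

    class _StubReadHead(nn.Module):
        """Hypothetical read head: returns the mean memory row, no addressing."""
        def __init__(self, memory):
            super(_StubReadHead, self).__init__()
            self.memory = memory

        def is_read_head(self):
            return True

        def create_new_state(self, batch_size):
            return None

        def forward(self, controller_outp, prev_state):
            r = self.memory.mem.mean(dim=0).unsqueeze(0)
            return r.expand(controller_outp.size(0), -1), prev_state

    num_inputs, num_outputs, batch_size, seq_len = 8, 8, 4, 5
    memory = _StubMemory(N=16, M=10)
    controller = _StubController(num_inputs + memory.size()[1], 32)
    heads = nn.ModuleList([_StubReadHead(memory)])

    ntm = NTM(num_inputs, num_outputs, controller, memory, heads)
    state = ntm.create_new_state(batch_size)

    # Feed a sequence one step at a time, threading the state through.
    for _ in range(seq_len):
        x = torch.rand(batch_size, num_inputs)
        o, state = ntm(x, state)
    print(o.shape)  # torch.Size([4, 8])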