forked from jcjohnson/torch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLanguageModel.lua
149 lines (119 loc) · 3.75 KB
/
LanguageModel.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
require 'torch'
require 'nn'
require 'SequenceRNN'
require 'SequenceLSTM'
local utils = require 'utils'
local LM, parent = torch.class('nn.LanguageModel', 'nn.Module')

--[[
Build the language model network:
  LookupTable(V, D) -> num_layers recurrent layers -> View -> Linear(H, V) -> View

Keyword arguments (read through utils.get_kwarg):
- idx_to_token: table mapping index -> token; required. Its entries define
  the vocabulary and the inverse map self.token_to_idx.
- cell_type: 'rnn' or 'lstm' (default 'lstm'); any other value is an error.
- wordvec_dim: embedding size D (default 128).
- hidden_dim: recurrent hidden size H (default 256).
- num_layers: number of stacked recurrent layers (default 1).
--]]
function LM:__init(kwargs)
  -- Initialise the nn.Module fields (output, gradInput, _type).
  parent.__init(self)

  self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
  self.token_to_idx = {}
  self.vocab_size = 0
  for idx, token in pairs(self.idx_to_token) do
    self.token_to_idx[token] = idx
    self.vocab_size = self.vocab_size + 1
  end

  self.cell_type = utils.get_kwarg(kwargs, 'cell_type', 'lstm')
  self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_dim', 128)
  self.hidden_dim = utils.get_kwarg(kwargs, 'hidden_dim', 256)
  self.num_layers = utils.get_kwarg(kwargs, 'num_layers', 1)

  -- Fail early with a clear message instead of indexing a nil rnn below.
  if self.cell_type ~= 'rnn' and self.cell_type ~= 'lstm' then
    error('Invalid cell_type "' .. tostring(self.cell_type)
          .. '"; expected "rnn" or "lstm"')
  end

  local V, D, H = self.vocab_size, self.wordvec_dim, self.hidden_dim

  self.net = nn.Sequential()
  self.rnns = {}
  self.net:add(nn.LookupTable(V, D))
  for i = 1, self.num_layers do
    -- The first layer consumes word vectors; later layers consume hidden states.
    local prev_dim = H
    if i == 1 then prev_dim = D end
    local rnn
    if self.cell_type == 'rnn' then
      rnn = nn.SequenceRNN(prev_dim, H)
    else
      rnn = nn.SequenceLSTM(prev_dim, H)
    end
    -- Carry state across forward calls until resetStates() is called.
    rnn.remember_states = true
    table.insert(self.rnns, rnn)
    self.net:add(rnn)
  end

  -- After all the RNNs run, we will have a tensor of shape (N, T, H);
  -- we want to apply a 1D temporal convolution to predict scores for each
  -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
  -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
  -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in
  -- between. Unfortunately N and T can change on every minibatch, so we need
  -- to set them in the forward pass.
  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
  self.view2 = nn.View(1, -1):setNumInputDims(2)
  self.net:add(self.view1)
  self.net:add(nn.Linear(H, V))
  self.net:add(self.view2)
end
--[[
Forward pass. The input is indexed as (N, T) via input:size(1) and
input:size(2). The output is the wrapped network's output, whose final view
is resized to (N, T, -1).

The result is stored in self.output, as the nn.Module contract requires,
before it is returned. Containers that read module.output directly
(e.g. nn.Sequential during backward) then see the current value.
--]]
function LM:updateOutput(input)
  local N, T = input:size(1), input:size(2)
  -- N and T can change between minibatches, so re-target the views that
  -- flatten around the Linear layer on every call.
  self.view1:resetSize(N * T, -1)
  self.view2:resetSize(N, T, -1)
  self.output = self.net:forward(input)
  return self.output
end
--[[
Backward pass through the wrapped network. Gradients with respect to the
parameters are accumulated inside self.net. The gradient with respect to
the input is stored in self.gradInput, per the nn.Module contract, and
returned.
--]]
function LM:backward(input, gradOutput, scale)
  self.gradInput = self.net:backward(input, gradOutput, scale)
  return self.gradInput
end
-- Delegate to the wrapped network so callers see the learnable parameters
-- (and their gradient tensors) of every layer inside self.net.
function LM:parameters()
  local net = self.net
  return net:parameters()
end
-- Ask each recurrent layer in self.rnns to reset its remembered state.
-- The layers are built with remember_states = true, so without this they
-- keep carrying state from one forward call to the next.
function LM:resetStates()
  local layers = self.rnns
  for k = 1, #layers do
    layers[k]:resetStates()
  end
end
--[[
Encode a string as a 1-D LongTensor of token indices.

The string is processed one byte at a time (s:sub(i, i)), and each byte is
looked up in self.token_to_idx. If a byte is not in the vocabulary, an error
is raised that names the offending character and its position.
--]]
function LM:encode_string(s)
  local encoded = torch.LongTensor(#s)
  for i = 1, #s do
    local token = s:sub(i, i)
    local idx = self.token_to_idx[token]
    if idx == nil then
      error(string.format(
        'Got invalid idx: token %q at position %d is not in the vocabulary',
        token, i), 2)
    end
    encoded[i] = idx
  end
  return encoded
end
--[[
Decode a 1-D tensor of token indices back into a string.

Each index is mapped through self.idx_to_token and the tokens are joined
with table.concat. This avoids the quadratic cost of repeated `..`
concatenation. An index with no token raises an error naming the index and
its position.
--]]
function LM:decode_string(encoded)
  assert(torch.isTensor(encoded) and encoded:dim() == 1)
  local tokens = {}
  for i = 1, encoded:size(1) do
    local idx = encoded[i]
    local token = self.idx_to_token[idx]
    if token == nil then
      error(string.format('Got invalid idx %s at position %d',
                          tostring(idx), i), 2)
    end
    tokens[i] = token
  end
  return table.concat(tokens)
end
--[[
Generate a sequence from the language model using greedy (argmax) decoding:
at each step the highest-scoring token is chosen. The states of the
underlying RNNs are reset both before and after generation.

Inputs:
- init: Either a string of length T0 (encoded with encode_string), or a
  (1, T0) tensor of token indices used to prime the model.
- max_length: Total length of the generated sequence, including init.
  Must be at least T0.

Returns:
- If init was a string: the decoded string of length max_length, whose
  prefix is init.
- Otherwise: sampled, a (1, max_length) LongTensor of token indices whose
  first T0 entries are init.
--]]
function LM:sample(init, max_length)
  local return_string = false
  if torch.type(init) == 'string' then
    return_string = true
    init = self:encode_string(init):view(1, -1)
  end
  local T0, T = init:size(2), max_length
  assert(T >= T0, 'max_length (' .. tostring(T)
         .. ') must be at least the length of init (' .. T0 .. ')')

  local sampled = torch.LongTensor(1, T)
  sampled[{{}, {1, T0}}]:copy(init)

  self:resetStates()
  -- Prime the RNN states on all of init; keep only the last timestep's scores.
  local scores = self:forward(init)[{{}, {T0, T0}}]
  for t = T0 + 1, T do
    local _, next_char = scores:max(3)
    next_char = next_char[{{}, {}, 1}]
    sampled[{{}, {t, t}}]:copy(next_char)
    -- After the final token the scores would never be used, so skip that pass.
    if t < T then
      scores = self:forward(next_char)
    end
  end
  self:resetStates()

  if return_string then
    return self:decode_string(sampled[1])
  end
  return sampled
end