local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
------------------------------------------------------------------------
--[[ SoftMaxTree ]]--
-- Computes the log of a product of softmaxes in a path.
-- Returns a 1D output tensor (one log-probability per example).
-- Only works with a tree (one parent per child).
------------------------------------------------------------------------
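-- A minimal usage sketch (the hierarchy, sizes and node ids below are made up
-- for illustration; any tree with positive node ids and one parent per child
-- works). forward{input, target} returns one log-likelihood per example, i.e.
-- the sum of log-softmax scores along the target leaf's path to the root:
--
-- > hierarchy = {
-- >    [1] = torch.IntTensor{2,3},     -- root (rootId = 1) has children 2 and 3
-- >    [2] = torch.IntTensor{4,5},     -- node 2 has leaf children 4 and 5
-- >    [3] = torch.IntTensor{6,7,8}    -- node 3 has leaf children 6, 7 and 8
-- > }
-- > smt = nn.SoftMaxTree(10, hierarchy, 1)    -- inputSize = 10, rootId = 1
-- > input = torch.randn(5, 10)                -- batch of 5 examples
-- > target = torch.IntTensor{4,5,6,7,8}       -- one leaf id per example
-- > output = smt:forward{input, target}       -- 1D tensor of 5 log-probabilities
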
function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, verbose)
   parent.__init(self)
   self.rootId = rootId or 1
   self.inputSize = inputSize
   self.accUpdate = accUpdate
   assert(type(hierarchy) == 'table', "Expecting table at arg 2")
   -- get the total number of children (non-root nodes)
   local nChildNode = 0
   local nParentNode = 0
   local maxNodeId = -999999999
   local minNodeId = 999999999
   local maxParentId = -999999999
   local maxChildId = -999999999
   local maxFamily = -999999999
   local parentIds = {}
   for parentId, children in pairs(hierarchy) do
      assert(children:dim() == 1, "Expecting table of 1D tensors at arg 2")
      nChildNode = nChildNode + children:size(1)
      nParentNode = nParentNode + 1
      maxParentId = math.max(parentId, maxParentId)
      maxFamily = math.max(maxFamily, children:size(1))
      local maxChildrenId = children:max()
      maxChildId = math.max(maxChildrenId, maxChildId)
      maxNodeId = math.max(parentId, maxNodeId, maxChildrenId)
      minNodeId = math.min(parentId, minNodeId, children:min())
      table.insert(parentIds, parentId)
   end
   if minNodeId < 0 then
      error("nodeIds must be positive: "..minNodeId, 2)
   end
   if verbose then
      print("Hierarchy has:")
      print(nParentNode.." parent nodes")
      print(nChildNode.." child nodes")
      print((nChildNode - nParentNode).." leaf nodes")
      print("node index will contain "..maxNodeId.." slots")
      if maxNodeId ~= (nChildNode + 1) then
         print("Warning: Hierarchy has more nodes than Ids")
         print("Consider making your nodeIds a contiguous sequence")
         print("in order to waste less memory on indexes.")
      end
   end
   self.nChildNode = nChildNode
   self.nParentNode = nParentNode
   self.minNodeId = minNodeId
   self.maxNodeId = maxNodeId
   self.maxParentId = maxParentId
   self.maxChildId = maxChildId
   self.maxFamily = maxFamily
   -- initialize weights and biases
   self.weight = torch.Tensor(self.nChildNode, self.inputSize)
   self.bias = torch.Tensor(self.nChildNode)
   if not self.accUpdate then
      self.gradWeight = torch.Tensor(self.nChildNode, self.inputSize)
      self.gradBias = torch.Tensor(self.nChildNode)
   end
   -- contains all childIds
   self.childIds = torch.IntTensor(self.nChildNode)
   -- contains all parentIds
   self.parentIds = torch.IntTensor(parentIds)
   -- index of children by parentId
   self.parentChildren = torch.IntTensor(self.maxParentId, 2):fill(-1)
   local start = 1
   for parentId, children in pairs(hierarchy) do
      local node = self.parentChildren:select(1, parentId)
      node[1] = start
      local nChildren = children:size(1)
      node[2] = nChildren
      self.childIds:narrow(1, start, nChildren):copy(children)
      start = start + nChildren
   end
   -- index of parent by childId
   self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1)
   for parentIdx=1,self.parentIds:size(1) do
      local parentId = self.parentIds[parentIdx]
      local node = self.parentChildren:select(1, parentId)
      local start = node[1]
      local nChildren = node[2]
      local children = self.childIds:narrow(1, start, nChildren)
      for childIdx=1,children:size(1) do
         local childId = children[childIdx]
         local child = self.childParent:select(1, childId)
         child[1] = parentId
         child[2] = childIdx
      end
   end
   -- used to allocate buffers
   -- max nChildren in a family path
   local maxFamilyPath = -999999999
   -- max number of parents (path depth)
   local maxDept = -999999999
   local treeSizes = {[self.rootId] = self.parentChildren[self.rootId][2]}
   local pathSizes = {[self.rootId] = 1}
   local function getSize(nodeId)
      local treeSize, pathSize = treeSizes[nodeId], pathSizes[nodeId]
      if not treeSize then
         local parentId = self.childParent[nodeId][1]
         local nChildren = self.parentChildren[nodeId][2]
         treeSize, pathSize = getSize(parentId)
         treeSize = treeSize + nChildren
         pathSize = pathSize + 1
         treeSizes[nodeId] = treeSize
         pathSizes[nodeId] = pathSize
      end
      return treeSize, pathSize
   end
   for parentIdx=1,self.parentIds:size(1) do
      local parentId = self.parentIds[parentIdx]
      local treeSize, pathSize = getSize(parentId)
      maxFamilyPath = math.max(treeSize, maxFamilyPath)
      maxDept = math.max(pathSize, maxDept)
   end
   self.maxFamilyPath = maxFamilyPath
   self.maxDept = maxDept
   -- stores the parentIds of nodes that have been updated by accGradParameters
   self.updates = {}
   -- used internally to store intermediate outputs or gradOutputs
   self._nodeBuffer = torch.Tensor()
   self._multiBuffer = torch.Tensor()
   self.batchSize = 0
   self._gradInput = torch.Tensor()
   self._gradTarget = torch.IntTensor() -- dummy
   self.gradInput = {self._gradInput, self._gradTarget}
   self.static = (static == nil) and true or static
   self:reset()
end
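
-- For the toy hierarchy in the sketch above, __init lays the indexes out
-- roughly as follows (the exact family order depends on pairs() iteration):
--
--   childIds       : {2,3, 4,5, 6,7,8}  -- children stored contiguously per family
--   parentChildren : row parentId -> {offset of its family in childIds/weight, nChildren}
--   childParent    : row childId  -> {parentId, position of the child within its family}
--
-- weight and bias share these offsets, so each parent's softmax parameters are
-- one contiguous narrow() of self.weight and self.bias.
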
function SoftMaxTree:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1/math.sqrt(self.nChildNode*self.inputSize)
   end
   self.weight:uniform(-stdv, stdv)
   self.bias:uniform(-stdv, stdv)
end

function SoftMaxTree:updateOutput(inputTable)
   local input, target = unpack(inputTable)
   -- buffers:
   if self.batchSize ~= input:size(1) then
      self._nodeBuffer:resize(self.maxFamily)
      self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
      self.batchSize = input:size(1)
      -- so that it works within nn.ConcatTable :
      self._gradTarget:resizeAs(target):zero()
      if self._nodeUpdateHost then
         self._nodeUpdateHost:resize(input:size(1), self.maxDept)
         self._nodeUpdateCuda:resize(input:size(1), self.maxDept)
      end
   end
   return input.nn.SoftMaxTree_updateOutput(self, input, target)
end

function SoftMaxTree:updateGradInput(inputTable, gradOutput)
   local input, target = unpack(inputTable)
   if not gradOutput:isContiguous() and torch.type(gradOutput) == 'torch.CudaTensor' then
      self._gradOutput = self._gradOutput or gradOutput.new()
      self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
      gradOutput = self._gradOutput
   end
   if self.gradInput then
      input.nn.SoftMaxTree_updateGradInput(self, input, gradOutput, target)
   end
   return self.gradInput
end

function SoftMaxTree:accGradParameters(inputTable, gradOutput, scale)
   local input, target = unpack(inputTable)
   gradOutput = self._gradOutput or gradOutput
   scale = scale or 1
   input.nn.SoftMaxTree_accGradParameters(self, input, gradOutput, target, scale)
end

-- when static is true, return parameters with static keys,
-- i.e. keys that don't change from batch to batch
function SoftMaxTree:parameters()
   local static = self.static
   local params, grads = {}, {}
   local updated = false
   for parentId, scale in pairs(self.updates) do
      local node = self.parentChildren:select(1, parentId)
      local parentIdx = node[1]
      local nChildren = node[2]
      if static then -- for use with pairs
         params[parentId] = self.weight:narrow(1, parentIdx, nChildren)
         local biasId = parentId+self.maxParentId
         params[biasId] = self.bias:narrow(1, parentIdx, nChildren)
         if not self.accUpdate then
            grads[parentId] = self.gradWeight:narrow(1, parentIdx, nChildren)
            grads[biasId] = self.gradBias:narrow(1, parentIdx, nChildren)
         end
      else -- for use with ipairs
         table.insert(params, self.weight:narrow(1, parentIdx, nChildren))
         table.insert(params, self.bias:narrow(1, parentIdx, nChildren))
         if not self.accUpdate then
            table.insert(grads, self.gradWeight:narrow(1, parentIdx, nChildren))
            table.insert(grads, self.gradBias:narrow(1, parentIdx, nChildren))
         end
      end
      updated = true
   end
   if not updated then
      if static then -- consistent with static = true
         for i=1,self.parentIds:size(1) do
            local parentId = self.parentIds[i]
            local node = self.parentChildren:select(1, parentId)
            local parentIdx = node[1]
            local nChildren = node[2]
            params[parentId] = self.weight:narrow(1, parentIdx, nChildren)
            local biasId = parentId+self.maxParentId
            params[biasId] = self.bias:narrow(1, parentIdx, nChildren)
            if not self.accUpdate then
               grads[parentId] = self.gradWeight:narrow(1, parentIdx, nChildren)
               grads[biasId] = self.gradBias:narrow(1, parentIdx, nChildren)
            end
         end
      else
         return {self.weight, self.bias}, {self.gradWeight, self.gradBias}
      end
   end
   return params, grads, {}, self.nChildNode*2
end
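
-- With static = true (the default) the tables returned above are keyed by node
-- id rather than being plain arrays, so the keys stay stable across batches and
-- must be traversed with pairs() rather than ipairs(). A sketch, assuming the
-- toy hierarchy above after a forward/backward that touched parents 1 and 2:
--
-- > params, grads = smt:parameters()
-- > params[1]                     -- weight rows of parent 1 (2 x inputSize)
-- > params[1 + smt.maxParentId]   -- bias of parent 1 (offset avoids key clashes)
-- > for k,p in pairs(params) do p:add(-0.1, grads[k]) end  -- what updateParameters does
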
function SoftMaxTree:updateParameters(learningRate)
   assert(not self.accUpdate)
   local params, gradParams = self:parameters()
   if params then
      for k,param in pairs(params) do
         param:add(-learningRate, gradParams[k])
      end
   end
end

function SoftMaxTree:getNodeParameters(parentId)
   local node = self.parentChildren:select(1, parentId)
   local start = node[1]
   local nChildren = node[2]
   local weight = self.weight:narrow(1, start, nChildren)
   local bias = self.bias:narrow(1, start, nChildren)
   if not self.accUpdate then
      local gradWeight = self.gradWeight:narrow(1, start, nChildren)
      local gradBias = self.gradBias:narrow(1, start, nChildren)
      return {weight, bias}, {gradWeight, gradBias}
   end
   return {weight, bias}
end

function SoftMaxTree:zeroGradParameters()
   local _,gradParams = self:parameters()
   for k,gradParam in pairs(gradParams) do
      gradParam:zero()
   end
   -- loop is used instead of 'self.updates = {}'
   -- to handle the case when updates are shared
   for k,v in pairs(self.updates) do
      self.updates[k] = nil
   end
end

function SoftMaxTree:type(type, typecache)
   if type == torch.type(self.weight) then
      return self
   end
   local hierarchy = self.hierarchy
   self.hierarchy = nil
   self._nodeUpdateHost = nil
   self._nodeUpdateCuda = nil
   self._paramUpdateHost = nil
   self._paramUpdateCuda = nil
   local parentChildren = self.parentChildren
   self.parentChildren = nil
   self.parentChildrenCuda = nil
   local childParent = self.childParent
   self.childParent = nil
   self.childParentCuda = nil
   local _gradTarget = self._gradTarget
   self._gradTarget = nil
   local childIds = self.childIds
   self.childIds = nil
   local parentIds = self.parentIds
   self.parentIds = nil
   self._gradOutput = nil
   parent.type(self, type, typecache)
   self.hierarchy = hierarchy
   self.parentChildren = parentChildren
   self.childParent = childParent
   self._gradTarget = _gradTarget
   self.childIds = childIds
   self.parentIds = parentIds
   if (type == 'torch.CudaTensor') then
      -- cunnx needs this for filling self.updates
      self._nodeUpdateHost = torch.IntTensor()
      self._nodeUpdateCuda = torch.CudaIntTensor()
      self._paramUpdateHost = torch.IntTensor()
      self._paramUpdateCuda = torch.CudaTensor()
      self.parentChildrenCuda = self.parentChildren:type(type)
      self.childParentCuda = self.childParent:type(type)
      self._gradTarget = self._gradTarget:type(type)
   elseif self._nodeUpdateHost then
      self._nodeUpdateHost = nil
      self._nodeUpdateCuda = nil
      self.parentChildren = self.parentChildren:type('torch.IntTensor')
      self.childParent = self.childParent:type('torch.IntTensor')
      self._gradTarget = self._gradTarget:type('torch.IntTensor')
   end
   self.gradInput = {self._gradInput, self._gradTarget}
   self.batchSize = 0 -- so that buffers are resized
   return self
end

function SoftMaxTree:maxNorm(maxNorm)
   local params = self:parameters()
   if params then
      for k,param in pairs(params) do
         if param:dim() == 2 and maxNorm then
            param:renorm(2, 1, maxNorm)
         end
      end
   end
end

function SoftMaxTree:momentumGradParameters()
   -- get dense view of momGradParams
   local _ = require 'moses'
   if not self.momGradParams or _.isEmpty(self.momGradParams) then
      assert(not self.accUpdate, "cannot use momentum with accUpdate")
      self.momGradParams = {self.gradWeight:clone():zero(), self.gradBias:clone():zero()}
   end
   local momGradParams = self.momGradParams
   if self.static and not _.isEmpty(self.updates) then
      local momGradWeight = momGradParams[1]
      local momGradBias = momGradParams[2]
      momGradParams = {}
      -- only return the parameters affected by the forward/backward
      for parentId, scale in pairs(self.updates) do
         local node = self.parentChildren:select(1, parentId)
         local parentIdx = node[1]
         local nChildren = node[2]
         momGradParams[parentId] = momGradWeight:narrow(1, parentIdx, nChildren)
         local biasId = parentId+self.maxParentId
         momGradParams[biasId] = momGradBias:narrow(1, parentIdx, nChildren)
      end
   end
   return momGradParams
end

-- we do not need to accumulate parameters when sharing
SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters