-- get_influece_per_word.lua
require 'rnn'
require 'hdf5'
require 'nngraph'
--[[
Computes the influence of each word in a context on the final
prediction and saves it. This is NOT the saliency.
The data should be formatted as a shifting window: each word
should be the last word of a context once.
Author: Sebastian Gehrmann
--]]
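-- Illustration only (not part of the original script): a minimal sketch of the
-- shifting-window layout this script expects, assuming a toy corpus of word ids
-- {1, 2, 3, 4, 5} and a context length of 3. Each row is one context, and every
-- word (past the first few) appears as the last word of exactly one row.
local function example_shifting_window()
    local corpus = torch.Tensor({1, 2, 3, 4, 5})
    local seqlength = 3
    local windows = {}
    for last = seqlength, corpus:size(1) do
        -- the context is the seqlength words ending at position `last`
        windows[#windows + 1] = corpus:narrow(1, last - seqlength + 1, seqlength)
    end
    return windows -- {[1 2 3], [2 3 4], [3 4 5]}
end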
cmd = torch.CmdLine()
cmd:option('-data_file','data/','File with data preprocessed in shifting window')
cmd:option('-gpuid',-1,'which gpu to use. -1 = use CPU')
cmd:option('-checkpoint_file','checkpoint/','path to model checkpoint file in t7 format')
cmd:option('-output_file','reads/lstm_grads.h5','path to output LSTM gradients in hdf5 format')
cmd:option('-embedding_index',2,'The index of the embedding layer in the model.')
opt = cmd:parse(arg)
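-- Example invocation (the paths, checkpoint name, and GPU id below are
-- illustrative, not taken from the original repository):
--   th get_influece_per_word.lua -data_file data/train.hdf5 \
--       -checkpoint_file checkpoint/lm_epoch10.t7 -output_file reads/lstm_grads.h5 -gpuid 0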
-- Construct the data set.
local data = torch.class("data")
function data:__init(opt, data_file, use_chars)
    local f = hdf5.open(data_file, 'r')
    self.target = f:read('target'):all()
    self.use_chars = use_chars
    self.target_output = f:read('target_output'):all()
    self.target_size = f:read('target_size'):all()[1]
    -- target is stored as (num_batches x batch_size x seq_length)
    self.length = self.target:size(1)
    self.seqlength = self.target:size(3)
    self.batchlength = self.target:size(2)
end
function data:size()
    return self.length
end
function data.__index(self, idx)
    local input, target
    if type(idx) == "string" then
        return data[idx]
    else
        -- Batches are moved to the GPU here, so the script as written expects -gpuid >= 0.
        input = self.target[idx]:transpose(1, 2):float():cuda()
        target = nn.SplitTable(2):forward(self.target_output[idx]:float():cuda())
    end
    return {input, target}
end
function get_influence(data, model, criterion)
    -- One influence value per (batch row, batch column, position) triple.
    local all_grads = torch.CudaTensor(data.length * data.batchlength * data.seqlength, 1)
    model:training() -- makes sure the gradients are stored
    for i = 1, data:size() do
        if i % 100 == 0 or i == 2 then
            print(i, "current batch")
        end
        model:zeroGradParameters()
        local d = data[i]
        input, goal = d[1], d[2]
        -- 1. forward
        local out = model:forward(input)
        -- 2. backward criterion
        local deriv = criterion:backward(out, goal)
        -- 3. zero out everything except the gradient at the final time step
        for z = 1, #deriv - 1 do
            deriv[z]:fill(0)
        end
        -- 4. backward model
        model:backward(input, deriv)
        -- :get(X) gets the Xth module of the model.
        -- It should point to the embedding layer.
        local gi = model:get(opt.embedding_index).gradInput:clone()
        -- Construct the new index in the all_grads tensor
        for csequence = 1, gi:size(1) do
            for cbatch = 1, gi:size(2) do
                local cindex = (i - 1) * data.seqlength
                local batchbonus = (cbatch - 1) * data.length
                local new_index = cindex + batchbonus + csequence
                -- compute the norm of the gradient as influence
                all_grads[new_index] = torch.norm(gi[csequence][cbatch])
            end
        end
        collectgarbage()
    end
    local f = hdf5.open(opt.output_file, 'w')
    f:write('grads', all_grads:float())
    f:close()
end
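-- Usage sketch (illustrative, not part of the original script): the gradients
-- written above can be read back with the same torch-hdf5 bindings, e.g. for a
-- quick sanity check. The default path matches the -output_file option above.
local function example_read_grads(path)
    local f = hdf5.open(path or 'reads/lstm_grads.h5', 'r')
    local grads = f:read('grads'):all() -- (length * batchlength * seqlength) x 1
    f:close()
    return grads
end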
function main()
    -- Parse input params
    opt = cmd:parse(arg)
    if opt.gpuid >= 0 then
        print('using CUDA on GPU ' .. opt.gpuid .. '...')
        require 'cutorch'
        require 'cunn'
        cutorch.setDevice(opt.gpuid + 1)
    end
    -- Create the data loader class.
    local train_data = data.new(opt, opt.data_file, opt.use_chars)
    -- Initialize the criterion and load the model checkpoint.
    criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
    model = torch.load(opt.checkpoint_file)
    if opt.gpuid >= 0 then
        model:cuda()
        criterion:cuda()
    end
    -- Compute the per-word influence scores and write them to disk.
    get_influence(train_data, model, criterion)
end
main()