-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathProbe.lua
56 lines (53 loc) · 1.8 KB
/
Probe.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
--- nn.Probe: a pass-through debugging module.
-- On every forward pass it prints statistics of its input, and on every
-- backward pass statistics of the incoming gradOutput, together with the
-- wall-clock time elapsed since the previous probe fired (any Probe instance).
-- Optionally, the tensor is also shown with `image.display`.
local Probe, parent = torch.class('nn.Probe', 'nn.Module')

--- Format a tensor's dimensions as e.g. "3x32x32" (empty string for 0 dims).
-- @param t tensor to describe
-- @treturn string
local function sizeString(t)
   local dims = {}
   for d = 1, t:dim() do
      dims[d] = t:size(d)
   end
   return table.concat(dims, 'x')
end

--- Print a statistics report for `t` under the heading `legend`.
-- Also advances the shared probe clock (`nn._ProbeLast`).
-- @string legend heading printed above the statistics
-- @param t tensor to summarise
local function report(legend, t)
   local size = sizeString(t)
   -- Sample the shared timer exactly once so that the interval reported here
   -- and the reference point stored for the next probe are the same instant.
   local now = nn._ProbeTimer:time().real
   local diff = now - (nn._ProbeLast or 0)
   nn._ProbeLast = now
   print('')
   print(legend)
   print(' + size = ' .. size)
   if t:nElement() > 0 then
      print(' + mean = ' .. t:mean())
      print(' + std = ' .. t:std())
      print(' + min = ' .. t:min())
      print(' + max = ' .. t:max())
   else
      -- Full-tensor reductions reject empty tensors; don't crash a debug probe.
      print(' + (empty tensor, no statistics)')
   end
   print(' + time since last probe = ' .. string.format('%0.1f', diff * 1000) .. 'ms')
end

--- Display `t` via the `image` package, reusing window `win` if given.
-- The package is required lazily so that `display` works even when the caller
-- never loaded `image` itself; `require` returns the cached table otherwise.
-- @return the window handle returned by image.display
local function show(t, win, legend)
   local img = require('image')
   return img.display{image = t, win = win, legend = legend}
end

--- Constructor.
-- @string[opt='unnamed'] name label used in the printed legends
-- @param[opt] display if truthy, also show tensors with image.display
function Probe:__init(name, display)
   parent.__init(self)
   self.name = name or 'unnamed'
   self.display = display
   -- One timer shared by every Probe, so intervals span across instances.
   nn._ProbeTimer = nn._ProbeTimer or torch.Timer()
end

--- Forward pass: report on `input` and return it unchanged (same object).
function Probe:updateOutput(input)
   self.output = input
   local legend = '<' .. self.name .. '>.output'
   report(legend, input)
   if self.display then
      self.winf = show(input, self.winf, legend)
   end
   return self.output
end

--- Backward pass: report on `gradOutput` and pass it through unchanged.
function Probe:updateGradInput(input, gradOutput)
   self.gradInput = gradOutput
   local legend = 'layer<' .. self.name .. '>.gradInput'
   report(legend, gradOutput)
   if self.display then
      self.winb = show(gradOutput, self.winb, legend)
   end
   return self.gradInput
end