-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit.lua
97 lines (88 loc) · 3.04 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
require 'nn'
-- Shared configuration table read by the pretty-printing helpers in this file.
-- Reuse an existing nn.config rather than replacing it, so settings stored
-- there by other code (or by an earlier load of this file) are not wiped out.
nn.config = nn.config or {}
-- Colored output is enabled by default; an already-configured value is kept.
if nn.config.prettyPrint == nil then
   nn.config.prettyPrint = true
end
--- Set or toggle ANSI-colored printing for the container printers below.
-- The flag is global (nn.config.prettyPrint), not stored on the module:
-- `self` is not read, the method is merely attached to nn.Container.
-- @param status value to store in nn.config.prettyPrint; when nil, the
--   current setting is negated instead.
function nn.Container:prettyPrint(status)
   local newValue = status
   if newValue == nil then
      newValue = not nn.config.prettyPrint
   end
   nn.config.prettyPrint = newValue
end
--- Render an nn.Sequential (or a subclass inheriting this method) as text.
-- The first line shows the data flow "[input -> (1) -> ... -> output]"; each
-- child module is then printed on its own line as "(i): <child>", with the
-- child's own newlines re-indented so nested containers stay aligned.
-- Structural characters are wrapped in blue ANSI escapes while
-- nn.config.prettyPrint is truthy.
-- @return string
function nn.Sequential:__tostring__()
   local function blue(s)
      if nn.config.prettyPrint then return '\27[0;34m' .. s .. '\27[0m' end
      return s
   end
   local tab = ' '
   local line = '\n'
   -- Named `arrow` (not `next`) so the builtin `next` is not shadowed.
   local arrow = blue ' -> '
   local modules = self.modules
   -- Use the runtime class name, as the Concat/Parallel printers do, so that
   -- subclasses of nn.Sequential are not labelled "nn.Sequential".
   local parts = { blue(torch.type(self)), blue ' {', line, tab, blue '[input' }
   for i = 1, #modules do
      parts[#parts + 1] = arrow .. blue '(' .. i .. blue ')'
   end
   parts[#parts + 1] = arrow .. blue 'output]'
   for i = 1, #modules do
      -- Parentheses keep only gsub's first result (the rewritten string).
      local child = (tostring(modules[i]):gsub(line, line .. tab))
      parts[#parts + 1] = line .. tab .. blue '(' .. i .. blue '): ' .. child
   end
   parts[#parts + 1] = line .. blue '}'
   return table.concat(parts)
end
--------------------------------------------------------------------------------
-- Concat
--------------------------------------------------------------------------------
--- Render an nn.Concat as a tree-shaped string (also installed on ConcatTable).
-- Each child hangs off "input" on its own line as "(i): <child>". Non-final
-- branches use a " |`-> " connector and re-indent the child's newlines with a
-- " | " guide; the final branch uses " `-> " and plain padding. The summary
-- ends with a "... -> output" line. Structural characters are wrapped in red
-- ANSI escapes while nn.config.prettyPrint is truthy.
-- @return string
function nn.Concat:__tostring__()
   local function red(s)
      if not nn.config.prettyPrint then return s end
      return '\27[0;31m' .. s .. '\27[0m'
   end
   local tab, line = ' ', '\n'
   local count = #self.modules
   local out = { red(torch.type(self)), red ' {', line, tab, red 'input' }
   for i = 1, count do
      local connector, guide
      if i == count then
         connector, guide = red ' `-> ', ' '
      else
         connector, guide = red ' |`-> ', red ' | '
      end
      -- Parentheses keep only gsub's first result (the rewritten string).
      local body = (tostring(self.modules[i]):gsub(line, line .. tab .. guide))
      out[#out + 1] = line .. tab .. connector .. red '(' .. i .. red '): ' .. body
   end
   out[#out + 1] = line .. tab .. red ' ... -> ' .. red 'output'
   out[#out + 1] = line .. red '}'
   return table.concat(out)
end
-- ConcatTable is printed with the same layout; torch.type(self) supplies its name.
nn.ConcatTable.__tostring__ = nn.Concat.__tostring__
--------------------------------------------------------------------------------
-- Parallel
--------------------------------------------------------------------------------
--- Render an nn.Parallel as a tree-shaped string (also installed on ParallelTable).
-- Children are listed under "input" as "(i): <child>"; every branch except
-- the last is drawn with " |`-> " and a " | " guide for the child's continued
-- lines, the last with " `-> " and plain padding, followed by a closing
-- "... -> output" line. Structural characters are wrapped in green ANSI
-- escapes while nn.config.prettyPrint is truthy.
-- @return string
function nn.Parallel:__tostring__()
   local green = function(text)
      if nn.config.prettyPrint then
         return '\27[0;32m' .. text .. '\27[0m'
      end
      return text
   end
   local indent = ' '
   local nl = '\n'
   local total = #self.modules
   local chunks = {}
   chunks[1] = green(torch.type(self)) .. green ' {' .. nl .. indent .. green 'input'
   for idx = 1, total do
      local branch, pad
      if idx < total then
         branch = green ' |`-> '
         pad = green ' | '
      else
         branch = green ' `-> '
         pad = ' '
      end
      local rendered = tostring(self.modules[idx])
      rendered = (rendered:gsub(nl, nl .. indent .. pad))
      chunks[#chunks + 1] = nl .. indent .. branch
         .. green '(' .. idx .. green '): ' .. rendered
   end
   chunks[#chunks + 1] = nl .. indent .. green ' ... -> ' .. green 'output'
   chunks[#chunks + 1] = nl .. green '}'
   return table.concat(chunks)
end
-- ParallelTable is printed with the same layout; torch.type(self) supplies its name.
nn.ParallelTable.__tostring__ = nn.Parallel.__tostring__