-- release_model.lua (forked from OpenNMT/OpenNMT)
-- Strips training-only state from a trained OpenNMT checkpoint so it can be
-- distributed as a smaller, CPU-loadable release model.
-- Load the OpenNMT framework (defines the global `onmt` namespace).
require('onmt.init')

local path = require('pl.path')

local cmd = onmt.utils.ExtendedCmdLine.new('release_model.lua')

-- Command-line options specific to this script.
local options = {
  {
    '-model', '',
    [[Path to the trained model to release.]],
    {
      valid = onmt.utils.ExtendedCmdLine.fileExists
    }
  },
  {
    '-output_model', '',
    [[Path to the released model. If not set, the `release` suffix will be automatically
added to the model filename.]]
  },
  {
    '-force', false,
    [[Force output model creation even if the target file exists.]]
  }
}

cmd:setCmdLineOptions(options, 'Model')

-- GPU and logging options shared with the rest of the toolkit.
onmt.utils.Cuda.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)

local opt = cmd:parse(arg)
--- Tell whether `object` is a container model (a plain table exposing a
-- `modules` list) rather than a single nn module.
-- NOTE: on a match this returns the `modules` field itself (truthy), not
-- a strict boolean — callers only use it in a boolean context.
local function isModel(object)
  if torch.type(object) ~= 'table' then
    return false
  end
  return object.modules
end
--- Strip training-only state from a single nn module, in place, so that it
-- serializes compactly and loads on CPU.
-- @param object the nn module to release.
-- @param tensorCache optional tensor cache shared across calls so converted
--   tensors keep their sharing relationships; a fresh cache is used when nil.
local function releaseModule(object, tensorCache)
  local cache = tensorCache or {}
  -- Let the module drop its own custom training resources first, if it can.
  if object.release then
    object:release()
  end
  -- Convert parameters to CPU float tensors and drop intermediate buffers.
  object:float(cache)
  object:clearState()
  object:apply(function(m)
    -- Gradient buffers are not needed once training is over.
    nn.utils.clear(m, 'gradWeight', 'gradBias')
    -- Drop function-valued fields (closures do not serialize portably).
    for key, value in pairs(m) do
      if type(value) == 'function' then
        m[key] = nil
      end
    end
  end)
end
--- Recursively release every sub-module of a container model, in place.
-- Nested containers (tables with a `modules` field) are walked recursively;
-- leaf modules are handed to `releaseModule`.
-- @param model a container model whose `modules` table is traversed.
-- @param tensorCache optional tensor cache threaded through the whole walk.
local function releaseModel(model, tensorCache)
  local cache = tensorCache or {}
  for _, child in pairs(model.modules) do
    -- Recurse into nested containers; release leaves directly.
    local release = isModel(child) and releaseModel or releaseModule
    release(child, cache)
  end
end
--- Entry point: load a trained checkpoint, strip training-only state from
-- every model it contains, and save the lightweight release checkpoint.
local function main()
  assert(path.exists(opt.model), 'model \'' .. opt.model .. '\' does not exist.')

  _G.logger = onmt.utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level, opt.log_tag)

  -- Derive the default output filename when none was given.
  if #opt.output_model == 0 then
    local base = opt.model
    if base:sub(-3) == '.t7' then
      -- Drop the '.t7' extension before appending the release suffix.
      base = base:sub(1, -4)
    end
    opt.output_model = base .. '_release.t7'
  end

  -- Refuse to clobber an existing file unless explicitly forced.
  if not opt.force then
    assert(not path.exists(opt.output_model),
           'output model already exists; use -force to overwrite.')
  end

  onmt.utils.Cuda.init(opt)

  _G.logger:info('Loading model \'' .. opt.model .. '\'...')

  -- torch.load fails for GPU checkpoints on CPU-only setups; surface a
  -- friendlier error pointing at the -gpuid option in that case.
  local ok, loaded = pcall(torch.load, opt.model)
  if not ok then
    error('unable to load the model (' .. loaded .. '). If you are releasing a GPU model, it needs to be loaded on the GPU first (set -gpuid > 0)')
  end
  local checkpoint = loaded

  _G.logger:info('... done.')
  _G.logger:info('Converting model...')

  -- Training metadata is not needed for inference.
  checkpoint.info = nil

  for _, object in pairs(checkpoint.models) do
    if isModel(object) then
      releaseModel(object)
    else
      releaseModule(object)
    end
  end

  _G.logger:info('... done.')

  _G.logger:info('Releasing model \'' .. opt.output_model .. '\'...')
  torch.save(opt.output_model, checkpoint)
  _G.logger:info('... done.')

  _G.logger:shutDown()
end

main()