forked from OpenNMT/OpenNMT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrest_translation_server.lua
185 lines (161 loc) · 5.27 KB
/
rest_translation_server.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env lua
--[[
  This requires the restserver-xavante rock to run.

  Run the server (this file):
    th tools/rest_translation_server.lua -model ../Recipes/baseline-1M-enfr/exp/model-baseline-1M-enfr_epoch13_3.44.t7 -gpuid 1

  Query the server:
    curl -v -H "Content-Type: application/json" -X POST -d '[{ "src" : "international migration" }]' http://127.0.0.1:7784/translator/translate
]]
require('onmt.init')
local tokenizer = require('tools.utils.tokenizer')
local BPE = require ('tools.utils.BPE')
local restserver = require("tools.restserver.restserver")
-- Command-line definition for the REST translation server.
local cmd = onmt.utils.ExtendedCmdLine.new('rest_translation_server.lua')

-- Options specific to the server itself; each entry is { flag, default, help }.
local options = {
  { '-host', '127.0.0.1', "Host to run the server on." },
  { '-port', '7784', "Port to run the server on." },
  { '-withAttn', false, "If set returns by default attn vector." }
}
cmd:setCmdLineOptions(options, 'Server')

-- Let every component used by the server register its own options,
-- in the same order as before.
for _, component in ipairs({ onmt.translate.Translator, onmt.utils.Cuda, onmt.utils.Logger, tokenizer }) do
  component.declareOpts(cmd)
end

-- Hooks get a look at the raw arguments before declaring their options.
onmt.utils.HookManager.updateOpt(arg, cmd)
onmt.utils.HookManager.declareOpts(cmd)

for _, heading in ipairs({ "", "Other options", "" }) do
  cmd:text(heading)
end
cmd:option('-batch_size', 64, "Size of each parallel batch - you should not change except if low memory.")

-- Parsed options, shared by the rest of this file.
local opt = cmd:parse(arg)
-- Re-raise an error caught by pcall around tokenization/detokenization.
-- An "interrupted" error is propagated as such; anything else is reported
-- as a unicode error (tokenization can fail on malformed utf-8 input).
-- The error value is stringified first so a non-string error object cannot
-- make the reporting itself fail. Level 2 attributes the error to the caller.
local function rethrowTokenizerError(err)
  local message = tostring(err)
  if string.find(message, "interrupted") then
    error("interrupted", 2)
  end
  error("unicode error in line " .. message, 2)
end

-- Translate the entries of a request.
--
-- translator: object providing buildInput, translate, buildOutput and args.n_best.
-- lines: array of request entries; each entry has a `src` string and may have
--        `feats` (array of source features appended to every token) and
--        `withAttn` (request the attention vectors for this entry).
--
-- Returns an array with one element per entry of `lines`; each element is an
-- array of n_best results holding `tgt`, `src`, `n_best`, `pred_score` and,
-- when attention was requested, `attn`.
local function translateMessage(translator, lines)
  _G.logger:debug("Start Tokenization")

  -- NOTE(review): the BPE object is rebuilt on every call; it looks like it
  -- could be created once at startup - confirm BPE.new has no per-request state.
  local bpe
  if opt.bpe_model ~= '' then
    bpe = BPE.new(opt)
  end

  local translations = {}
  local i = 1
  while i <= #lines do
    -- Index in `lines` of the first entry of this batch; needed to map a
    -- position inside the batch back to the request entry it came from.
    local batchStart = i
    local batch = {}

    while i <= #lines and #batch < opt.batch_size do
      local line = lines[i]
      local tokens
      local ok, err = pcall(function()
        local preprocessed = _G.hookManager:call("mpreprocess", opt, line.src) or line.src
        tokens = tokenizer.tokenize(opt, preprocessed, bpe)
      end)
      -- it can generate an exception if there are utf-8 issues in the text
      if not ok then
        rethrowTokenizerError(err)
      end

      -- Add custom source features if they are provided in the request.
      -- This is usually used for domain control.
      if line.feats then
        for j = 1, #tokens do
          for _, feat in ipairs(line.feats) do
            if feat ~= '' then
              tokens[j] = tokens[j] .. '│' .. feat
            end
          end
        end
      end

      -- Join and re-split on whitespace so the translator receives
      -- whitespace-free words, exactly as before.
      local srcTokens = {}
      for word in table.concat(tokens, ' '):gmatch'([^%s]+)' do
        table.insert(srcTokens, word)
      end

      table.insert(batch, translator:buildInput(srcTokens))
      i = i + 1
    end

    -- Translate
    _G.logger:debug("Start Translation")
    local results = translator:translate(batch)
    _G.logger:debug("End Translation")

    -- Return the nbest translations for each in the batch.
    for b = 1, #batch do
      -- The batch-local index b only equals the request index for the first
      -- batch; use the real request entry for per-entry settings.
      local line = lines[batchStart + b - 1]
      local ret = {}
      for bi = 1, translator.args.n_best do
        local pred = results[b].preds[bi]
        local srcSent = translator:buildOutput(batch[b])
        local predSent
        local ok, err = pcall(function()
          predSent = tokenizer.detokenize(opt, pred.words, pred.features)
        end)
        if not ok then
          rethrowTokenizerError(err)
        end

        local lineres = {
          tgt = predSent,
          src = srcSent,
          n_best = bi,
          pred_score = pred.score
        }
        if opt.withAttn or line.withAttn then
          local attnTable = {}
          for j = 1, #pred.attention do
            table.insert(attnTable, pred.attention[j]:totable())
          end
          lineres.attn = attnTable
        end
        table.insert(ret, lineres)
      end
      table.insert(translations, ret)
    end
  end

  return translations
end
-- Build (but do not start) the REST server bound to host:port.
-- It exposes a single resource, "translator", with a POST /translate route
-- that consumes and produces JSON and delegates to translateMessage.
local function init_server(host, port, translator)
  -- Handler for POST /translate: translate the request and answer 200.
  local function handle_translate(req)
    _G.logger:debug("receiving request")
    local translate = translateMessage(translator, req)
    _G.logger:debug("sending response")
    return restserver.response():status(200):entity(translate)
  end

  local server = restserver:new():host(host):port(port)

  local translate_route = {
    method = "POST",
    path = "/translate",
    consumes = "application/json",
    produces = "application/json",
    handler = handle_translate,
  }
  server:add_resource("translator", { translate_route })

  return server
end
-- Entry point: set up the global services (logger, hook manager, profiler),
-- initialise Cuda, load the translation model and start the REST server.
local function main()
  local utils = onmt.utils

  -- The logger is installed first; the steps below log through _G.logger.
  _G.logger = utils.Logger.new(opt.log_file, opt.disable_logs, opt.log_level, opt.log_tag)

  utils.Cuda.init(opt)

  _G.hookManager = utils.HookManager.new(opt)

  -- Profiling is switched off for the server.
  _G.profiler = utils.Profiler.new(false)

  _G.logger:info("Loading model")
  local translator = onmt.translate.Translator.new(opt)

  _G.logger:info("Launch server")
  -- Enabling "restserver.xavante" loads the xavante plugin before starting.
  init_server(opt.host, opt.port, translator):enable("restserver.xavante"):start()
end

main()