Skip to content

Commit 8872895

Browse files
committed
Also add a function handler.
1 parent 2e43447 commit 8872895

File tree

6 files changed

+102
-12
lines changed

6 files changed

+102
-12
lines changed

http/api_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,13 @@ func TestMTLSServerWithClient(t *testing.T) {
262262
)
263263
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "test/test_mtls_server_with_client.lua"))
264264
}
265+
266+
func TestServer(t *testing.T) {
267+
preload := tests.SeveralPreloadFuncs(
268+
lua_http.Preload,
269+
lua_time.Preload,
270+
inspect.Preload,
271+
plugin.Preload,
272+
)
273+
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "test/test_server.lua"))
274+
}

http/loader.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@ func Loader(L *lua.LState) int {
4949
http_server_ud := L.NewTypeMetatable(`http_server_ud`)
5050
L.SetGlobal(`http_server_ud`, http_server_ud)
5151
L.SetField(http_server_ud, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
52-
"accept": server.Accept,
53-
"addr": server.Addr,
54-
"do_handle_file": server.HandleFile,
55-
"do_handle_string": server.HandleString,
52+
"accept": server.Accept,
53+
"addr": server.Addr,
54+
"do_handle_file": server.HandleFile,
55+
"do_handle_string": server.HandleString,
56+
"do_handle_function": server.HandleFunction,
5657
}))
5758

5859
t := L.NewTable()

http/server/loader.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ func Loader(L *lua.LState) int {
3030
http_server_ud := L.NewTypeMetatable(`http_server_ud`)
3131
L.SetGlobal(`http_server_ud`, http_server_ud)
3232
L.SetField(http_server_ud, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
33-
"accept": Accept,
34-
"addr": Addr,
35-
"do_handle_file": HandleFile,
36-
"do_handle_string": HandleString,
33+
"accept": Accept,
34+
"addr": Addr,
35+
"do_handle_file": HandleFile,
36+
"do_handle_string": HandleString,
37+
"do_handle_function": HandleFunction,
3738
}))
3839

3940
t := L.NewTable()

http/server/server.go

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ func HandleFile(L *lua.LState) int {
236236
case data := <-s.serveData:
237237
go func(sData *serveData, filename string) {
238238
state := newHandlerState(data)
239+
defer state.Close()
239240
if err := state.DoFile(filename); err != nil {
240241
log.Printf("[ERROR] handle file %s: %s\n", filename, err.Error())
241242
data.done <- true
@@ -248,7 +249,7 @@ func HandleFile(L *lua.LState) int {
248249
return 0
249250
}
250251

251-
// HandleString lua http_server_ud:handler_string(body)
252+
// HandleString lua http_server_ud:handle_string(body)
252253
func HandleString(L *lua.LState) int {
253254
s := checkServer(L, 1)
254255
body := L.CheckString(2)
@@ -257,6 +258,7 @@ func HandleString(L *lua.LState) int {
257258
case data := <-s.serveData:
258259
go func(sData *serveData, content string) {
259260
state := newHandlerState(sData)
261+
defer state.Close()
260262
if err := state.DoString(content); err != nil {
261263
log.Printf("[ERROR] handle: %s\n", err.Error())
262264
data.done <- true
@@ -268,6 +270,48 @@ func HandleString(L *lua.LState) int {
268270
return 0
269271
}
270272

273+
// HandleFunction lua http_server_ud:handle_function(func(response, request))
274+
func HandleFunction(L *lua.LState) int {
275+
s := checkServer(L, 1)
276+
f := L.CheckFunction(2)
277+
if len(f.Upvalues) > 0 {
278+
L.ArgError(2, "cannot pass closures")
279+
}
280+
281+
// Stash any args to pass to the function beyond response and request
282+
var args []lua.LValue
283+
top := L.GetTop()
284+
for i := 3; i <= top; i++ {
285+
args = append(args, L.Get(i))
286+
}
287+
288+
for {
289+
select {
290+
case data := <-s.serveData:
291+
go func(sData *serveData) {
292+
state := newHandlerState(sData)
293+
defer state.Close()
294+
response := state.GetGlobal("response")
295+
request := state.GetGlobal("request")
296+
f := state.NewFunctionFromProto(f.Proto)
297+
state.Push(f)
298+
state.Push(response)
299+
state.Push(request)
300+
// Push any extra args
301+
for _, arg := range args {
302+
state.Push(arg)
303+
}
304+
if err := state.PCall(2+len(args), 0, nil); err != nil {
305+
log.Printf("[ERROR] handle: %s\n", err.Error())
306+
data.done <- true
307+
log.Printf("[ERROR] closed connection\n")
308+
}
309+
state.Pop(state.GetTop())
310+
}(data)
311+
}
312+
}
313+
}
314+
271315
// ServeHTTP interface realisation
272316
func (s *luaServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
273317
doneChan := make(chan bool)

http/test/test_mtls_server_with_client.lua

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@ function TestMTLSServerWithClient(t)
1414
}
1515
assert(not err, tostring(err))
1616
addr_ch:send(server:addr())
17-
while true do
18-
local request, response = server:accept()
17+
server:do_handle_string([=[
1918
response:code(200)
2019
response:write("OK\n")
2120
response:done()
22-
end
21+
]=])
2322
]]
2423
local server_plugin = plugin.do_string(server_body, addr_ch)
2524
server_plugin:run()

http/test/test_server.lua

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
local http = require 'http'
2+
local plugin = require 'plugin'
3+
local time = require 'time'
4+
local inspect = require 'inspect'
5+
6+
function Test_do_handle_function(t)
7+
local addr_ch = channel.make(1)
8+
local server_body = [[
9+
local addr_ch = unpack(arg)
10+
local http = require 'http'
11+
local server, err = http.server {}
12+
assert(not err, tostring(err))
13+
addr_ch:send(server:addr())
14+
server:do_handle_function(function(response, request)
15+
print(string.format("response = %s", response))
16+
response:write("OK\n")
17+
response:code(200)
18+
response:done()
19+
end)
20+
]]
21+
local server_plugin = plugin.do_string(server_body, addr_ch)
22+
server_plugin:run()
23+
time.sleep(1)
24+
local server_plugin_error = server_plugin:error()
25+
assert(not server_plugin_error, tostring(server_plugin_error))
26+
local ok, addr = addr_ch:receive(1)
27+
assert(ok, "addr not ok")
28+
local tURL = string.format("http://%s/", addr)
29+
30+
local client = http.client()
31+
local req = http.request('GET', tURL)
32+
local resp, err = client:do_request(req)
33+
assert(not err, tostring(err))
34+
assert(resp.code == 200, tostring(resp.code))
35+
end

0 commit comments

Comments
 (0)