Skip to content

Commit 2e43447

Browse files
committed
Also add server mtls handling code
1 parent 0defb30 commit 2e43447

File tree

6 files changed

+151
-15
lines changed

6 files changed

+151
-15
lines changed

http/api_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func TestApi(t *testing.T) {
225225
})
226226
}
227227

228-
func TestMTLS(t *testing.T) {
228+
func TestMTLSClient(t *testing.T) {
229229
s := httptest.NewUnstartedServer(http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) {
230230
_, _ = io.WriteString(writer, "OK\n")
231231
}))
@@ -250,5 +250,15 @@ func TestMTLS(t *testing.T) {
250250
L.SetGlobal("tURL", lua.LString(s.URL))
251251
},
252252
)
253-
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "test/test_api.lua"))
253+
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "test/test_mtls_client.lua"))
254+
}
255+
256+
func TestMTLSServerWithClient(t *testing.T) {
257+
preload := tests.SeveralPreloadFuncs(
258+
lua_http.Preload,
259+
lua_time.Preload,
260+
inspect.Preload,
261+
plugin.Preload,
262+
)
263+
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "test/test_mtls_server_with_client.lua"))
254264
}

http/loader.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// Preload adds http to the given Lua state's package.preload table. After it
1313
// has been preloaded, it can be loaded using require:
1414
//
15-
// local http = require("http")
15+
// local http = require("http")
1616
func Preload(L *lua.LState) {
1717
L.PreloadModule("http", Loader)
1818
client.Preload(L)
@@ -50,6 +50,7 @@ func Loader(L *lua.LState) int {
5050
L.SetGlobal(`http_server_ud`, http_server_ud)
5151
L.SetField(http_server_ud, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
5252
"accept": server.Accept,
53+
"addr": server.Addr,
5354
"do_handle_file": server.HandleFile,
5455
"do_handle_string": server.HandleString,
5556
}))

http/server/loader.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
// Preload adds http_server to the given Lua state's package.preload table. After it
1010
// has been preloaded, it can be loaded using require:
1111
//
12-
// local http_server = require("http_server")
12+
// local http_server = require("http_server")
1313
func Preload(L *lua.LState) {
1414
L.PreloadModule("http_server", Loader)
1515
}
@@ -31,6 +31,7 @@ func Loader(L *lua.LState) int {
3131
L.SetGlobal(`http_server_ud`, http_server_ud)
3232
L.SetField(http_server_ud, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
3333
"accept": Accept,
34+
"addr": Addr,
3435
"do_handle_file": HandleFile,
3536
"do_handle_string": HandleString,
3637
}))

http/server/server.go

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package http
22

33
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
ioutil2 "io/ioutil"
47
"log"
58
"net"
69
"net/http"
@@ -82,13 +85,70 @@ func (s *luaServer) serve(L *lua.LState) {
8285

8386
// http.server(bind, handler) returns (user data, error)
8487
func New(L *lua.LState) int {
85-
bind := L.CheckAny(1).String()
88+
var tlsConfig *tls.Config
89+
bind := "127.0.0.1:0"
90+
switch bindOrTable := L.CheckAny(1).(type) {
91+
case lua.LString:
92+
bind = string(bindOrTable)
93+
case *lua.LTable:
94+
if addr, ok := L.GetField(bindOrTable, "addr").(lua.LString); ok {
95+
bind = string(addr)
96+
}
97+
serverPublicCertPEMFile := L.GetField(bindOrTable, `server_public_cert_pem_file`)
98+
serverPrivateKeyPemFile := L.GetField(bindOrTable, `server_private_key_pem_file`)
99+
if serverPublicCertPEMFile != lua.LNil && serverPrivateKeyPemFile != lua.LNil {
100+
serverCert, err := tls.LoadX509KeyPair(serverPublicCertPEMFile.String(), serverPrivateKeyPemFile.String())
101+
if err != nil {
102+
L.RaiseError("error loading server cert: %v", err)
103+
}
104+
tlsConfig = &tls.Config{
105+
Certificates: []tls.Certificate{serverCert},
106+
}
107+
108+
clientAuth := L.GetField(bindOrTable, "client_auth")
109+
if clientAuth != lua.LNil {
110+
if _, ok := clientAuth.(lua.LString); !ok {
111+
L.ArgError(1, "client_auth should be a string")
112+
}
113+
switch clientAuth.String() {
114+
case "NoClientCert":
115+
tlsConfig.ClientAuth = tls.NoClientCert
116+
case "RequestClientCert":
117+
tlsConfig.ClientAuth = tls.RequestClientCert
118+
case "RequireAnyClientCert":
119+
tlsConfig.ClientAuth = tls.RequireAnyClientCert
120+
case "VerifyClientCertIfGiven":
121+
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
122+
case "RequireAndVerifyClientCert":
123+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
124+
}
125+
}
126+
127+
clientCAs := L.GetField(bindOrTable, "client_cas_pem_file")
128+
if clientCAs != lua.LNil {
129+
if _, ok := clientCAs.(lua.LString); !ok {
130+
L.ArgError(1, "client_cas_pem_file must be a string")
131+
}
132+
data, err := ioutil2.ReadFile(clientCAs.String())
133+
if err != nil {
134+
L.RaiseError("error reading %s: %v", clientCAs, err)
135+
}
136+
tlsConfig.ClientCAs = x509.NewCertPool()
137+
if !tlsConfig.ClientCAs.AppendCertsFromPEM(data) {
138+
L.RaiseError("no certs loaded from %s", clientCAs)
139+
}
140+
}
141+
}
142+
}
86143
l, err := net.Listen(`tcp`, bind)
87144
if err != nil {
88145
L.Push(lua.LNil)
89146
L.Push(lua.LString(err.Error()))
90147
return 2
91148
}
149+
if tlsConfig != nil {
150+
l = tls.NewListener(l, tlsConfig)
151+
}
92152
server := &luaServer{
93153
Listener: l,
94154
serveData: make(chan *serveData, 1),
@@ -112,6 +172,13 @@ func Accept(L *lua.LState) int {
112172
}
113173
}
114174

175+
// Addr returns the address if, for instance, one listens on :0
176+
func Addr(L *lua.LState) int {
177+
s := checkServer(L, 1)
178+
L.Push(lua.LString(s.Listener.Addr().String()))
179+
return 1
180+
}
181+
115182
func newHandlerState(data *serveData) *lua.LState {
116183
state := lua.NewState()
117184

@@ -185,16 +252,18 @@ func HandleFile(L *lua.LState) int {
185252
func HandleString(L *lua.LState) int {
186253
s := checkServer(L, 1)
187254
body := L.CheckString(2)
188-
select {
189-
case data := <-s.serveData:
190-
go func(sData *serveData, content string) {
191-
state := newHandlerState(sData)
192-
if err := state.DoString(content); err != nil {
193-
log.Printf("[ERROR] handle: %s\n", err.Error())
194-
data.done <- true
195-
log.Printf("[ERROR] closed connection\n")
196-
}
197-
}(data, body)
255+
for {
256+
select {
257+
case data := <-s.serveData:
258+
go func(sData *serveData, content string) {
259+
state := newHandlerState(sData)
260+
if err := state.DoString(content); err != nil {
261+
log.Printf("[ERROR] handle: %s\n", err.Error())
262+
data.done <- true
263+
log.Printf("[ERROR] closed connection\n")
264+
}
265+
}(data, body)
266+
}
198267
}
199268
return 0
200269
}
File renamed without changes.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
local http = require 'http'
2+
local plugin = require 'plugin'
3+
local time = require 'time'
4+
local inspect = require 'inspect'
5+
6+
function TestMTLSServerWithClient(t)
7+
local addr_ch = channel.make(1)
8+
local server_body = [[
9+
local addr_ch = unpack(arg)
10+
local http = require 'http'
11+
local server, err = http.server {
12+
server_public_cert_pem_file = "test/data/test.cert.pem",
13+
server_private_key_pem_file = "test/data/test.key.pem",
14+
}
15+
assert(not err, tostring(err))
16+
addr_ch:send(server:addr())
17+
while true do
18+
local request, response = server:accept()
19+
response:code(200)
20+
response:write("OK\n")
21+
response:done()
22+
end
23+
]]
24+
local server_plugin = plugin.do_string(server_body, addr_ch)
25+
server_plugin:run()
26+
time.sleep(1)
27+
local server_plugin_error = server_plugin:error()
28+
assert(not server_plugin_error, tostring(server_plugin_error))
29+
local ok, addr = addr_ch:receive(1)
30+
assert(ok, "addr not ok")
31+
local tURL = string.format("https://%s/", addr)
32+
33+
t:Run('no-client-cert fails', function(t)
34+
local client = http.client()
35+
local req, err = http.request("GET", tURL)
36+
assert(not err, tostring(err))
37+
local resp, err = client:do_request(req)
38+
assert(err, tostring(err))
39+
end)
40+
41+
t:Run('client-cert passes', function(t)
42+
local client = http.client {
43+
root_cas_pem_file = 'test/data/test.cert.pem',
44+
client_public_cert_pem_file = 'test/data/test.cert.pem',
45+
client_private_key_pem_file = 'test/data/test.key.pem',
46+
}
47+
local req, err = http.request("GET", tURL)
48+
assert(not err, tostring(err))
49+
local resp, err = client:do_request(req)
50+
assert(not err, tostring(err))
51+
assert(resp.code == 200, tostring(resp.code))
52+
end)
53+
54+
server_plugin:stop()
55+
end

0 commit comments

Comments
 (0)