Skip to content

Commit 00a4cdc

Browse files
committed
Test server TLS config factory
1 parent 7e0baaf commit 00a4cdc

23 files changed

+4285
-28
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ lint: ## Lint
4747
golint $$(go list ./...) 2>&1
4848

4949
test: ## Test
50-
GO111MODULE=on go test -mod=vendor $(BUILD_FLAGS) -v ./...
50+
GO111MODULE=on go test -count=1 -mod=vendor $(BUILD_FLAGS) -v ./...
51+
52+
test.race: ## Test with race detection
53+
GO111MODULE=on go test -race -count=1 -mod=vendor $(BUILD_FLAGS) -v ./...
5154

5255
build: vet ## Build executable
5356
CGO_ENABLED=1 GO111MODULE=on go build -mod=vendor -o $(BINARY) $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" .

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ require (
88
github.com/pkg/errors v0.9.1
99
github.com/prometheus/client_golang v1.13.0
1010
github.com/prometheus/common v0.37.0
11-
github.com/stretchr/testify v1.7.1
11+
github.com/stretchr/testify v1.8.1
1212
github.com/sykesm/zap-logfmt v0.0.4
1313
go.uber.org/atomic v1.9.0
1414
go.uber.org/automaxprocs v1.5.1

go.sum

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
253253
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
254254
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
255255
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
256+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
257+
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
256258
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
257259
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
258260
github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -261,6 +263,9 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
261263
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
262264
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
263265
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
266+
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
267+
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
268+
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
264269
github.com/sykesm/zap-logfmt v0.0.4 h1:U2WzRvmIWG1wDLCFY3sz8UeEmsdHQjHFNlIdmroVFaI=
265270
github.com/sykesm/zap-logfmt v0.0.4/go.mod h1:AuBd9xQjAe3URrWT1BBDk2v2onAZHkZkWRMiYZXiZWA=
266271
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

pkg/tls/cert/filesource/filesource.go

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ const (
1717
)
1818

1919
type fileSource struct {
20-
certFile string
21-
keyFile string
22-
clientAuthFile string
23-
clientCRLFile string
24-
refresh time.Duration
25-
logger log.Logger
26-
20+
certFile string
21+
keyFile string
22+
clientAuthFile string
23+
clientCRLFile string
24+
refresh time.Duration
25+
logger log.Logger
26+
notifyFunc func()
2727
lastServerCerts atomic.Pointer[tlscert.ServerCerts]
2828
}
2929

@@ -48,6 +48,14 @@ func New(opts ...Option) (tlscert.ServerSource, error) {
4848
return s, nil
4949
}
5050

51+
func MustNew(opts ...Option) tlscert.ServerSource {
52+
serverSource, err := New(opts...)
53+
if err != nil {
54+
panic(`filesource: New(): ` + err.Error())
55+
}
56+
return serverSource
57+
}
58+
5159
func (s *fileSource) getServerCerts() (*tlscert.ServerCerts, error) {
5260
pemBlocks, err := s.Load()
5361
if err != nil {
@@ -91,10 +99,14 @@ func (s *fileSource) ServerCerts() chan tlscert.ServerCerts {
9199
if initialServerCert != nil {
92100
ch <- *initialServerCert
93101
}
94-
go func() {
95-
tlscert.Watch(s.logger, ch, s.refresh, initialServerCert, s.refreshServerCerts)
102+
if s.refresh <= 0 {
96103
close(ch)
97-
}()
104+
} else {
105+
go func() {
106+
tlscert.Watch(s.logger, ch, s.refresh, initialServerCert, s.refreshServerCerts, s.notifyFunc)
107+
close(ch)
108+
}()
109+
}
98110
return ch
99111
}
100112

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
package filesource
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"github.com/grepplabs/mqtt-proxy/pkg/log"
7+
servertls "github.com/grepplabs/mqtt-proxy/pkg/tls"
8+
"github.com/stretchr/testify/require"
9+
"io"
10+
"net/http"
11+
"net/http/httptest"
12+
"net/url"
13+
"os"
14+
"testing"
15+
"time"
16+
)
17+
18+
func TestServerConfig(t *testing.T) {
19+
logger := log.GetInstance()
20+
bundle := newCertsBundle()
21+
defer bundle.Close()
22+
23+
tests := []struct {
24+
name string
25+
transportFunc func() http.RoundTripper
26+
configFunc func() *tls.Config
27+
requestError bool
28+
}{
29+
{
30+
name: "Client unknown authority",
31+
transportFunc: func() http.RoundTripper {
32+
return newRoundTripper()
33+
},
34+
configFunc: func() *tls.Config {
35+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
36+
WithLogger(logger),
37+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
38+
))
39+
},
40+
requestError: true,
41+
},
42+
{
43+
name: "Client insecure",
44+
transportFunc: func() http.RoundTripper {
45+
return newRoundTripper(withClientTLSSkipVerify(true))
46+
},
47+
configFunc: func() *tls.Config {
48+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
49+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
50+
))
51+
},
52+
},
53+
{
54+
name: "Client trusted CA",
55+
transportFunc: func() http.RoundTripper {
56+
return newRoundTripper(withRootCAs(bundle.CAX509Cert))
57+
},
58+
configFunc: func() *tls.Config {
59+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
60+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
61+
))
62+
},
63+
},
64+
{
65+
name: "Client without required certificate",
66+
transportFunc: func() http.RoundTripper {
67+
return newRoundTripper(withRootCAs(bundle.CAX509Cert))
68+
},
69+
configFunc: func() *tls.Config {
70+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
71+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
72+
WithClientAuthFile(bundle.CACert.Name()),
73+
))
74+
},
75+
requestError: true,
76+
},
77+
{
78+
name: "Client verification success",
79+
transportFunc: func() http.RoundTripper {
80+
return newRoundTripper(withRootCAs(bundle.CAX509Cert), withClientCertificate(bundle.ClientTLSCert))
81+
},
82+
configFunc: func() *tls.Config {
83+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
84+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
85+
WithClientAuthFile(bundle.CACert.Name()),
86+
))
87+
},
88+
},
89+
{
90+
name: "Client verification success - empty CRL",
91+
transportFunc: func() http.RoundTripper {
92+
return newRoundTripper(withRootCAs(bundle.CAX509Cert), withClientCertificate(bundle.ClientTLSCert))
93+
},
94+
configFunc: func() *tls.Config {
95+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
96+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
97+
WithClientAuthFile(bundle.CACert.Name()),
98+
WithClientCRLFile(bundle.CAEmptyCRL.Name()),
99+
))
100+
},
101+
},
102+
{
103+
name: "Client certificate revoked",
104+
transportFunc: func() http.RoundTripper {
105+
return newRoundTripper(withRootCAs(bundle.CAX509Cert), withClientCertificate(bundle.ClientTLSCert))
106+
},
107+
configFunc: func() *tls.Config {
108+
return servertls.MustNewServerConfig(log.GetInstance(), MustNew(
109+
WithX509KeyPair(bundle.ServerCert.Name(), bundle.ServerKey.Name()),
110+
WithClientAuthFile(bundle.CACert.Name()),
111+
WithClientCRLFile(bundle.ClientCRL.Name()),
112+
))
113+
},
114+
requestError: true,
115+
},
116+
}
117+
for _, tc := range tests {
118+
t.Run(tc.name, func(t *testing.T) {
119+
// given
120+
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121+
w.WriteHeader(http.StatusOK)
122+
}))
123+
defer ts.Close()
124+
ts.TLS = tc.configFunc()
125+
ts.StartTLS()
126+
127+
httpClient := &http.Client{
128+
Transport: tc.transportFunc(),
129+
}
130+
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
131+
require.NoError(t, err)
132+
133+
// when
134+
res, err := httpClient.Do(req)
135+
136+
// then
137+
if tc.requestError {
138+
t.Log(err)
139+
require.NotNil(t, err)
140+
return
141+
}
142+
require.NoError(t, err)
143+
144+
_, err = io.ReadAll(res.Body)
145+
require.NoError(t, err)
146+
147+
_ = res.Body.Close()
148+
require.NoError(t, err)
149+
require.Equal(t, res.StatusCode, http.StatusOK)
150+
151+
})
152+
}
153+
}
154+
155+
func TestCertRotation(t *testing.T) {
156+
bundle1 := newCertsBundle()
157+
defer bundle1.Close()
158+
159+
bundle2 := newCertsBundle()
160+
defer bundle2.Close()
161+
162+
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
163+
w.WriteHeader(http.StatusOK)
164+
}))
165+
defer ts.Close()
166+
167+
rotatedCh := make(chan struct{}, 1)
168+
notifyFunc := func() {
169+
rotatedCh <- struct{}{}
170+
}
171+
source := MustNew(
172+
WithX509KeyPair(bundle1.ServerCert.Name(), bundle1.ServerKey.Name()),
173+
WithClientAuthFile(bundle1.CACert.Name()),
174+
WithClientCRLFile(bundle1.CAEmptyCRL.Name()),
175+
WithRefresh(1*time.Second),
176+
WithNotifyFunc(notifyFunc),
177+
).(*fileSource)
178+
179+
ts.TLS = servertls.MustNewServerConfig(log.GetInstance(), source)
180+
ts.StartTLS()
181+
182+
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
183+
require.NoError(t, err)
184+
185+
// when
186+
_, err = bundle1.newHttpClient().Do(req)
187+
require.NoError(t, err)
188+
189+
// copy new certificates to be used by server
190+
require.NoError(t, os.Rename(bundle2.ServerCert.Name(), bundle1.ServerCert.Name()))
191+
require.NoError(t, os.Rename(bundle2.ServerKey.Name(), bundle1.ServerKey.Name()))
192+
require.NoError(t, os.Rename(bundle2.CACert.Name(), bundle1.CACert.Name()))
193+
require.NoError(t, os.Rename(bundle2.CAEmptyCRL.Name(), bundle1.CAEmptyCRL.Name()))
194+
195+
select {
196+
case <-rotatedCh:
197+
t.Log("certificates were changed")
198+
time.Sleep(100 * time.Millisecond)
199+
case <-time.After(3 * time.Second):
200+
t.Fatal("expected certificate change notification")
201+
}
202+
// old client - bad certificate
203+
_, err = bundle1.newHttpClient().Do(req)
204+
require.NotNil(t, err)
205+
var unknownAuthorityError x509.UnknownAuthorityError
206+
require.ErrorAs(t, err.(*url.Error).Err, &unknownAuthorityError)
207+
208+
// new client - success
209+
_, err = bundle2.newHttpClient().Do(req)
210+
require.NoError(t, err)
211+
}

0 commit comments

Comments
 (0)