Skip to content

Commit bb730e7

Browse files
committed
clientv3/integration: test TLS reload
Signed-off-by: Gyu-Ho Lee <gyuhox@gmail.com>
1 parent 9408f66 commit bb730e7

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

clientv3/integration/dial_test.go

+136
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
package integration
1616

1717
import (
18+
"io"
19+
"io/ioutil"
1820
"math/rand"
21+
"os"
22+
"path/filepath"
1923
"testing"
2024
"time"
2125

@@ -66,6 +70,89 @@ func TestDialTLSExpired(t *testing.T) {
6670
}
6771
}
6872

73+
// TestDialTLSExpiredReload ensures server reloads expired certs,
74+
// rejecting client requests, and vice versa.
75+
func TestDialTLSExpiredReload(t *testing.T) {
76+
defer testutil.AfterTest(t)
77+
78+
ts, err := copyTLSFiles(testTLSInfo)
79+
if err != nil {
80+
t.Fatal(err)
81+
}
82+
certsDir := filepath.Dir(ts.KeyFile)
83+
defer os.RemoveAll(certsDir)
84+
85+
tse, err := copyTLSFiles(testTLSInfoExpired)
86+
if err != nil {
87+
t.Fatal(err)
88+
}
89+
dir2 := filepath.Dir(tse.KeyFile)
90+
defer os.RemoveAll(dir2)
91+
92+
var tmpDir string
93+
tmpDir, err = ioutil.TempDir(os.TempDir(), "fixtures")
94+
if err != nil {
95+
t.Fatal(err)
96+
}
97+
os.RemoveAll(tmpDir)
98+
defer os.RemoveAll(tmpDir)
99+
100+
// start with valid certs
101+
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
102+
defer clus.Terminate(t)
103+
104+
// replace certs directory with expired ones
105+
if err = os.Rename(certsDir, tmpDir); err != nil {
106+
t.Fatal(err)
107+
}
108+
if err = os.Rename(dir2, certsDir); err != nil {
109+
t.Fatal(err)
110+
}
111+
112+
// 'tmpDir' now has valid certs
113+
// 'certsDir' now has expired certs; 'dir2' does not exist
114+
115+
// now server expects 'tls: bad certificate'
116+
// on incoming client requests
117+
tls, err := ts.ClientConfig()
118+
if err != nil {
119+
t.Fatal(err)
120+
}
121+
_, err = clientv3.New(clientv3.Config{
122+
Endpoints: []string{clus.Members[0].GRPCAddr()},
123+
DialTimeout: 3 * time.Second,
124+
TLS: tls,
125+
})
126+
if err != grpc.ErrClientConnTimeout {
127+
t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err)
128+
}
129+
130+
// swap expired certs back with valid ones
131+
if err = os.Rename(tmpDir, dir2); err != nil {
132+
t.Fatal(err)
133+
}
134+
if err = os.Rename(certsDir, tmpDir); err != nil {
135+
t.Fatal(err)
136+
}
137+
if err = os.Rename(dir2, certsDir); err != nil {
138+
t.Fatal(err)
139+
}
140+
tls, err = ts.ClientConfig()
141+
if err != nil {
142+
t.Fatal(err)
143+
}
144+
var cl *clientv3.Client
145+
cl, err = clientv3.New(clientv3.Config{
146+
Endpoints: []string{clus.Members[0].GRPCAddr()},
147+
DialTimeout: 3 * time.Second,
148+
TLS: tls,
149+
})
150+
defer cl.Close()
151+
if err != nil {
152+
t.Fatalf("expected no error, got %v", err)
153+
}
154+
}
155+
69156
// TestDialSetEndpoints ensures SetEndpoints can replace unavailable endpoints with available ones.
70157
func TestDialSetEndpointsBeforeFail(t *testing.T) {
71158
testDialSetEndpoints(t, true)
@@ -173,3 +260,52 @@ func TestDialForeignEndpoint(t *testing.T) {
173260
t.Fatal(err)
174261
}
175262
}
263+
264+
// copyTLSFiles clones certs files to temp directory.
265+
func copyTLSFiles(ti transport.TLSInfo) (transport.TLSInfo, error) {
266+
tmpdir, err := ioutil.TempDir(os.TempDir(), "fixtures")
267+
if err != nil {
268+
return transport.TLSInfo{}, err
269+
}
270+
ci := transport.TLSInfo{
271+
KeyFile: filepath.Join(tmpdir, "server-key.pem"),
272+
CertFile: filepath.Join(tmpdir, "server.pem"),
273+
TrustedCAFile: filepath.Join(tmpdir, "etcd-root-ca.pem"),
274+
ClientCertAuth: ti.ClientCertAuth,
275+
}
276+
if err = copyFile(ti.KeyFile, ci.KeyFile); err != nil {
277+
return transport.TLSInfo{}, err
278+
}
279+
if err = copyFile(ti.CertFile, ci.CertFile); err != nil {
280+
return transport.TLSInfo{}, err
281+
}
282+
if err = copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil {
283+
return transport.TLSInfo{}, err
284+
}
285+
return ci, nil
286+
}
287+
288+
func copyFile(src, dst string) error {
289+
f, err := os.Open(src)
290+
if err != nil {
291+
return err
292+
}
293+
defer f.Close()
294+
295+
w, err := os.Create(dst)
296+
if err != nil {
297+
return err
298+
}
299+
defer w.Close()
300+
301+
if _, err = io.Copy(w, f); err != nil {
302+
return err
303+
}
304+
if err = w.Sync(); err != nil {
305+
return err
306+
}
307+
if _, err = w.Seek(0, io.SeekStart); err != nil {
308+
return err
309+
}
310+
return nil
311+
}

0 commit comments

Comments
 (0)