|
15 | 15 | package integration
|
16 | 16 |
|
17 | 17 | import (
|
| 18 | + "io" |
| 19 | + "io/ioutil" |
18 | 20 | "math/rand"
|
| 21 | + "os" |
| 22 | + "path/filepath" |
19 | 23 | "testing"
|
20 | 24 | "time"
|
21 | 25 |
|
@@ -66,6 +70,89 @@ func TestDialTLSExpired(t *testing.T) {
|
66 | 70 | }
|
67 | 71 | }
|
68 | 72 |
|
| 73 | +// TestDialTLSExpiredReload ensures server reloads expired certs, |
| 74 | +// rejecting client requests, and vice versa. |
| 75 | +func TestDialTLSExpiredReload(t *testing.T) { |
| 76 | + defer testutil.AfterTest(t) |
| 77 | + |
| 78 | + ts, err := copyTLSFiles(testTLSInfo) |
| 79 | + if err != nil { |
| 80 | + t.Fatal(err) |
| 81 | + } |
| 82 | + certsDir := filepath.Dir(ts.KeyFile) |
| 83 | + defer os.RemoveAll(certsDir) |
| 84 | + |
| 85 | + tse, err := copyTLSFiles(testTLSInfoExpired) |
| 86 | + if err != nil { |
| 87 | + t.Fatal(err) |
| 88 | + } |
| 89 | + dir2 := filepath.Dir(tse.KeyFile) |
| 90 | + defer os.RemoveAll(dir2) |
| 91 | + |
| 92 | + var tmpDir string |
| 93 | + tmpDir, err = ioutil.TempDir(os.TempDir(), "fixtures") |
| 94 | + if err != nil { |
| 95 | + t.Fatal(err) |
| 96 | + } |
| 97 | + os.RemoveAll(tmpDir) |
| 98 | + defer os.RemoveAll(tmpDir) |
| 99 | + |
| 100 | + // start with valid certs |
| 101 | + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) |
| 102 | + defer clus.Terminate(t) |
| 103 | + |
| 104 | + // replace certs directory with expired ones |
| 105 | + if err = os.Rename(certsDir, tmpDir); err != nil { |
| 106 | + t.Fatal(err) |
| 107 | + } |
| 108 | + if err = os.Rename(dir2, certsDir); err != nil { |
| 109 | + t.Fatal(err) |
| 110 | + } |
| 111 | + |
| 112 | + // 'tmpDir' now has valid certs |
| 113 | + // 'certsDir' now has expired certs; 'dir2' does not exist |
| 114 | + |
| 115 | + // now server expects 'tls: bad certificate' |
| 116 | + // on incoming client requests |
| 117 | + tls, err := ts.ClientConfig() |
| 118 | + if err != nil { |
| 119 | + t.Fatal(err) |
| 120 | + } |
| 121 | + _, err = clientv3.New(clientv3.Config{ |
| 122 | + Endpoints: []string{clus.Members[0].GRPCAddr()}, |
| 123 | + DialTimeout: 3 * time.Second, |
| 124 | + TLS: tls, |
| 125 | + }) |
| 126 | + if err != grpc.ErrClientConnTimeout { |
| 127 | + t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err) |
| 128 | + } |
| 129 | + |
| 130 | + // swap expired certs back with valid ones |
| 131 | + if err = os.Rename(tmpDir, dir2); err != nil { |
| 132 | + t.Fatal(err) |
| 133 | + } |
| 134 | + if err = os.Rename(certsDir, tmpDir); err != nil { |
| 135 | + t.Fatal(err) |
| 136 | + } |
| 137 | + if err = os.Rename(dir2, certsDir); err != nil { |
| 138 | + t.Fatal(err) |
| 139 | + } |
| 140 | + tls, err = ts.ClientConfig() |
| 141 | + if err != nil { |
| 142 | + t.Fatal(err) |
| 143 | + } |
| 144 | + var cl *clientv3.Client |
| 145 | + cl, err = clientv3.New(clientv3.Config{ |
| 146 | + Endpoints: []string{clus.Members[0].GRPCAddr()}, |
| 147 | + DialTimeout: 3 * time.Second, |
| 148 | + TLS: tls, |
| 149 | + }) |
| 150 | + defer cl.Close() |
| 151 | + if err != nil { |
| 152 | + t.Fatalf("expected no error, got %v", err) |
| 153 | + } |
| 154 | +} |
| 155 | + |
69 | 156 | // TestDialSetEndpoints ensures SetEndpoints can replace unavailable endpoints with available ones.
|
70 | 157 | func TestDialSetEndpointsBeforeFail(t *testing.T) {
|
71 | 158 | testDialSetEndpoints(t, true)
|
@@ -173,3 +260,52 @@ func TestDialForeignEndpoint(t *testing.T) {
|
173 | 260 | t.Fatal(err)
|
174 | 261 | }
|
175 | 262 | }
|
| 263 | + |
| 264 | +// copyTLSFiles clones certs files to temp directory. |
| 265 | +func copyTLSFiles(ti transport.TLSInfo) (transport.TLSInfo, error) { |
| 266 | + tmpdir, err := ioutil.TempDir(os.TempDir(), "fixtures") |
| 267 | + if err != nil { |
| 268 | + return transport.TLSInfo{}, err |
| 269 | + } |
| 270 | + ci := transport.TLSInfo{ |
| 271 | + KeyFile: filepath.Join(tmpdir, "server-key.pem"), |
| 272 | + CertFile: filepath.Join(tmpdir, "server.pem"), |
| 273 | + TrustedCAFile: filepath.Join(tmpdir, "etcd-root-ca.pem"), |
| 274 | + ClientCertAuth: ti.ClientCertAuth, |
| 275 | + } |
| 276 | + if err = copyFile(ti.KeyFile, ci.KeyFile); err != nil { |
| 277 | + return transport.TLSInfo{}, err |
| 278 | + } |
| 279 | + if err = copyFile(ti.CertFile, ci.CertFile); err != nil { |
| 280 | + return transport.TLSInfo{}, err |
| 281 | + } |
| 282 | + if err = copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil { |
| 283 | + return transport.TLSInfo{}, err |
| 284 | + } |
| 285 | + return ci, nil |
| 286 | +} |
| 287 | + |
| 288 | +func copyFile(src, dst string) error { |
| 289 | + f, err := os.Open(src) |
| 290 | + if err != nil { |
| 291 | + return err |
| 292 | + } |
| 293 | + defer f.Close() |
| 294 | + |
| 295 | + w, err := os.Create(dst) |
| 296 | + if err != nil { |
| 297 | + return err |
| 298 | + } |
| 299 | + defer w.Close() |
| 300 | + |
| 301 | + if _, err = io.Copy(w, f); err != nil { |
| 302 | + return err |
| 303 | + } |
| 304 | + if err = w.Sync(); err != nil { |
| 305 | + return err |
| 306 | + } |
| 307 | + if _, err = w.Seek(0, io.SeekStart); err != nil { |
| 308 | + return err |
| 309 | + } |
| 310 | + return nil |
| 311 | +} |
0 commit comments