Skip to content

Commit e2b40c4

Browse files
authored
fix: resolve data race on closing client (GoogleCloudPlatform#62)
This is a port of GoogleCloudPlatform/cloud-sql-proxy#1245
1 parent bca543c commit e2b40c4

File tree

4 files changed

+200
-82
lines changed

4 files changed

+200
-82
lines changed

cmd/root.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,11 @@ func runSignalWrapper(cmd *Command) error {
384384
case p = <-startCh:
385385
}
386386
cmd.Println("The proxy has started successfully and is ready for new connections!")
387-
defer p.Close()
387+
defer func() {
388+
if cErr := p.Close(); cErr != nil {
389+
cmd.PrintErrf("error during shutdown: %v\n", cErr)
390+
}
391+
}()
388392

389393
go func() {
390394
shutdownCh <- p.Serve(ctx)

cmd/root_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ func (*spyDialer) Close() error {
362362
}
363363

364364
func TestCommandWithCustomDialer(t *testing.T) {
365-
want := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
365+
want := "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
366366
s := &spyDialer{}
367367
c := NewCommand(WithDialer(s))
368368
// Keep the test output quiet

internal/proxy/proxy.go

Lines changed: 107 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -176,76 +176,21 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
176176
}
177177
}
178178

179-
pc := newPortConfig(conf.Port)
180179
var mnts []*socketMount
180+
pc := newPortConfig(conf.Port)
181181
for _, inst := range conf.Instances {
182-
var (
183-
// network is one of "tcp" or "unix"
184-
network string
185-
// address is either a TCP host port, or a Unix socket
186-
address string
187-
)
188-
// IF
189-
// a global Unix socket directory is NOT set AND
190-
// an instance-level Unix socket is NOT set
191-
// (e.g., I didn't set a Unix socket globally or for this instance)
192-
// OR
193-
// an instance-level TCP address or port IS set
194-
// (e.g., I'm overriding any global settings to use TCP for this
195-
// instance)
196-
// use a TCP listener.
197-
// Otherwise, use a Unix socket.
198-
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
199-
(inst.Addr != "" || inst.Port != 0) {
200-
network = "tcp"
201-
202-
a := conf.Addr
203-
if inst.Addr != "" {
204-
a = inst.Addr
205-
}
206-
207-
var np int
208-
switch {
209-
case inst.Port != 0:
210-
np = inst.Port
211-
case conf.Port != 0:
212-
np = pc.nextPort()
213-
default:
214-
np = pc.nextPort()
215-
}
216-
217-
address = net.JoinHostPort(a, fmt.Sprint(np))
218-
} else {
219-
network = "unix"
220-
221-
dir := conf.UnixSocket
222-
if dir == "" {
223-
dir = inst.UnixSocket
224-
}
225-
ud, err := UnixSocketDir(dir, inst.Name)
226-
if err != nil {
227-
return nil, err
228-
}
229-
// Create the parent directory that will hold the socket.
230-
if _, err := os.Stat(ud); err != nil {
231-
if err = os.Mkdir(ud, 0777); err != nil {
232-
return nil, err
233-
}
234-
}
235-
// use the Postgres-specific socket name
236-
address = filepath.Join(ud, ".s.PGSQL.5432")
237-
}
238-
239-
m := &socketMount{inst: inst.Name}
240-
addr, err := m.listen(ctx, network, address)
182+
m, err := newSocketMount(ctx, conf, pc, inst)
241183
if err != nil {
242184
for _, m := range mnts {
243-
m.close()
185+
mErr := m.Close()
186+
if mErr != nil {
187+
cmd.PrintErrf("failed to close mount: %v", mErr)
188+
}
244189
}
245190
return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err)
246191
}
247192

248-
cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String())
193+
cmd.Printf("[%s] Listening on %s\n", inst.Name, m.Addr())
249194
mnts = append(mnts, m)
250195
}
251196

@@ -277,22 +222,45 @@ func (c *Client) Serve(ctx context.Context) error {
277222
return <-exitCh
278223
}
279224

280-
// Close triggers the proxyClient to shutdown.
281-
func (c *Client) Close() {
282-
defer c.dialer.Close()
225+
// MultiErr is a group of errors wrapped into one.
226+
type MultiErr []error
227+
228+
// Error returns a single string representing one or more errors.
229+
func (m MultiErr) Error() string {
230+
l := len(m)
231+
if l == 1 {
232+
return m[0].Error()
233+
}
234+
var errs []string
235+
for _, e := range m {
236+
errs = append(errs, e.Error())
237+
}
238+
return strings.Join(errs, ", ")
239+
}
240+
241+
func (c *Client) Close() error {
242+
var mErr MultiErr
283243
for _, m := range c.mnts {
284-
m.close()
244+
err := m.Close()
245+
if err != nil {
246+
mErr = append(mErr, err)
247+
}
248+
}
249+
cErr := c.dialer.Close()
250+
if cErr != nil {
251+
mErr = append(mErr, cErr)
252+
}
253+
if len(mErr) > 0 {
254+
return mErr
285255
}
256+
return nil
286257
}
287258

288259
// serveSocketMount persistently listens to the socketMounts listener and proxies connections to a
289260
// given AlloyDB instance.
290261
func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error {
291-
if s.listener == nil {
292-
return fmt.Errorf("[%s] mount doesn't have a listener set", s.inst)
293-
}
294262
for {
295-
cConn, err := s.listener.Accept()
263+
cConn, err := s.Accept()
296264
if err != nil {
297265
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
298266
c.cmd.PrintErrf("[%s] Error accepting connection: %v\n", s.inst, err)
@@ -327,22 +295,82 @@ type socketMount struct {
327295
listener net.Listener
328296
}
329297

330-
// listen causes a socketMount to create a Listener at the specified network address.
331-
func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) {
298+
func newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) {
299+
var (
300+
// network is one of "tcp" or "unix"
301+
network string
302+
// address is either a TCP host port, or a Unix socket
303+
address string
304+
)
305+
// IF
306+
// a global Unix socket directory is NOT set AND
307+
// an instance-level Unix socket is NOT set
308+
// (e.g., I didn't set a Unix socket globally or for this instance)
309+
// OR
310+
// an instance-level TCP address or port IS set
311+
// (e.g., I'm overriding any global settings to use TCP for this
312+
// instance)
313+
// use a TCP listener.
314+
// Otherwise, use a Unix socket.
315+
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
316+
(inst.Addr != "" || inst.Port != 0) {
317+
network = "tcp"
318+
319+
a := conf.Addr
320+
if inst.Addr != "" {
321+
a = inst.Addr
322+
}
323+
324+
var np int
325+
switch {
326+
case inst.Port != 0:
327+
np = inst.Port
328+
default:
329+
np = pc.nextPort()
330+
}
331+
332+
address = net.JoinHostPort(a, fmt.Sprint(np))
333+
} else {
334+
network = "unix"
335+
336+
dir := conf.UnixSocket
337+
if dir == "" {
338+
dir = inst.UnixSocket
339+
}
340+
ud, err := UnixSocketDir(dir, inst.Name)
341+
if err != nil {
342+
return nil, err
343+
}
344+
// Create the parent directory that will hold the socket.
345+
if _, err := os.Stat(ud); err != nil {
346+
if err = os.Mkdir(ud, 0777); err != nil {
347+
return nil, err
348+
}
349+
}
350+
// use the Postgres-specific socket name
351+
address = filepath.Join(ud, ".s.PGSQL.5432")
352+
}
353+
332354
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
333-
l, err := lc.Listen(ctx, network, address)
355+
ln, err := lc.Listen(ctx, network, address)
334356
if err != nil {
335357
return nil, err
336358
}
337-
s.listener = l
338-
return s.listener.Addr(), nil
359+
m := &socketMount{inst: inst.Name, listener: ln}
360+
return m, nil
361+
}
362+
363+
func (s *socketMount) Addr() net.Addr {
364+
return s.listener.Addr()
365+
}
366+
367+
func (s *socketMount) Accept() (net.Conn, error) {
368+
return s.listener.Accept()
339369
}
340370

341371
// close stops the mount from listening for any more connections
342-
func (s *socketMount) close() error {
343-
err := s.listener.Close()
344-
s.listener = nil
345-
return err
372+
func (s *socketMount) Close() error {
373+
return s.listener.Close()
346374
}
347375

348376
// proxyConn sets up a bidirectional copy between two open connections

internal/proxy/proxy_test.go

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ package proxy_test
1616

1717
import (
1818
"context"
19+
"errors"
1920
"io/ioutil"
2021
"net"
2122
"os"
2223
"path/filepath"
2324
"testing"
25+
"time"
2426

2527
"cloud.google.com/go/alloydbconn"
2628
"github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy"
@@ -37,13 +39,22 @@ type testCase struct {
3739
}
3840

3941
func (fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) {
40-
return nil, nil
42+
conn, _ := net.Pipe()
43+
return conn, nil
4144
}
4245

4346
func (fakeDialer) Close() error {
4447
return nil
4548
}
4649

50+
type errorDialer struct {
51+
fakeDialer
52+
}
53+
54+
func (errorDialer) Close() error {
55+
return errors.New("errorDialer returns error on Close")
56+
}
57+
4758
func createTempDir(t *testing.T) (string, func()) {
4859
testDir, err := ioutil.TempDir("", "*")
4960
if err != nil {
@@ -216,6 +227,81 @@ func TestClientInitialization(t *testing.T) {
216227
}
217228
}
218229

230+
func TestClientClosesCleanly(t *testing.T) {
231+
in := &proxy.Config{
232+
Addr: "127.0.0.1",
233+
Port: 5000,
234+
Instances: []proxy.InstanceConnConfig{
235+
{Name: "proj:reg:inst"},
236+
},
237+
Dialer: fakeDialer{},
238+
}
239+
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
240+
if err != nil {
241+
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
242+
}
243+
go c.Serve(context.Background())
244+
time.Sleep(time.Second) // allow the socket to start listening
245+
246+
conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
247+
if dErr != nil {
248+
t.Fatalf("net.Dial error = %v", dErr)
249+
}
250+
_ = conn.Close()
251+
252+
if err := c.Close(); err != nil {
253+
t.Fatalf("c.Close() error = %v", err)
254+
}
255+
}
256+
257+
func TestClosesWithError(t *testing.T) {
258+
in := &proxy.Config{
259+
Addr: "127.0.0.1",
260+
Port: 5000,
261+
Instances: []proxy.InstanceConnConfig{
262+
{Name: "proj:reg:inst"},
263+
},
264+
Dialer: errorDialer{},
265+
}
266+
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
267+
if err != nil {
268+
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
269+
}
270+
go c.Serve(context.Background())
271+
time.Sleep(time.Second) // allow the socket to start listening
272+
273+
if err = c.Close(); err == nil {
274+
t.Fatal("c.Close() should error, got nil")
275+
}
276+
}
277+
278+
func TestMultiErrorFormatting(t *testing.T) {
279+
tcs := []struct {
280+
desc string
281+
in proxy.MultiErr
282+
want string
283+
}{
284+
{
285+
desc: "with one error",
286+
in: proxy.MultiErr{errors.New("woops")},
287+
want: "woops",
288+
},
289+
{
290+
desc: "with many errors",
291+
in: proxy.MultiErr{errors.New("woops"), errors.New("another error")},
292+
want: "woops, another error",
293+
},
294+
}
295+
296+
for _, tc := range tcs {
297+
t.Run(tc.desc, func(t *testing.T) {
298+
if got := tc.in.Error(); got != tc.want {
299+
t.Errorf("want = %v, got = %v", tc.want, got)
300+
}
301+
})
302+
}
303+
}
304+
219305
func TestClientInitializationWorksRepeatedly(t *testing.T) {
220306
// The client creates a Unix socket on initial startup and does not remove
221307
// it on shutdown. This test ensures the existing socket does not cause

0 commit comments

Comments
 (0)