diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index 2068e4618..269ea290c 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -43,24 +43,37 @@ func randTmpDir(t interface { // newTestClient is a convenience function for testing that creates a // proxy.Client and starts it. The returned cleanup function is also a // convenience. Callers may choose to ignore it and manually close the client. -func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, func()) { +func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) (*proxy.Client, chan error, func()) { conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} + + // This context is only used to call the Cloud SQL API c, err := proxy.NewClient(context.Background(), d, testLogger, conf) if err != nil { t.Fatalf("want error = nil, got = %v", err) } ready := make(chan struct{}) - go c.Serve(context.Background(), func() { close(ready) }) + servErrCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + + servErr := c.Serve(ctx, func() { close(ready) }) + select { + case servErrCh <- servErr: + case <-ctx.Done(): + } + }() select { case <-ready: case <-time.Tick(5 * time.Second): t.Fatal("failed to Serve") } - return c, func() { + return c, servErrCh, func() { if cErr := c.Close(); cErr != nil { t.Logf("failed to close client: %v", cErr) } + cancel() } } @@ -70,7 +83,7 @@ func TestFUSEREADME(t *testing.T) { } dir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, dir, randTmpDir(t)) fi, err := os.Stat(dir) if err != nil { @@ -161,7 +174,7 @@ func TestFUSEDialInstance(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) + _, _, cleanup := newTestClient(t, d, fuseDir, tc.fuseTempDir) defer cleanup() conn := tryDialUnix(t, tc.socketPath) @@ -185,13 +198,66 @@ func TestFUSEDialInstance(t *testing.T) { }) } } +func TestFUSEAcceptErrorReturnedFromServe(t *testing.T) { + if testing.Short() { + t.Skip("skipping fuse tests in short mode.") + } + + fuseDir := randTmpDir(t) + fuseTempDir := randTmpDir(t) + socketPath := filepath.Join(fuseDir, "proj:region:mysql") + + // Create a new client + d := &fakeDialer{} + c, servErrCh, cleanup := newTestClient(t, d, fuseDir, fuseTempDir) + defer cleanup() + + // Attempt a successful connection to the client + conn := tryDialUnix(t, socketPath) + defer conn.Close() + + // Ensure that the client actually fully connected. + // This solves a race condition in the test that is only present on + // the Ubuntu-Latest platform. + var got []string + for i := 0; i < 10; i++ { + got = d.dialedInstances() + if len(got) == 1 { + break + } + time.Sleep(100 * time.Millisecond) + } + if len(got) != 1 { + t.Fatalf("dialed instances len: want = 1, got = %v", got) + } + + // Explicitly close the dialer. This will close all the unix sockets, forcing + // the unix socket accept goroutine to exit with an error + c.Close() + + // Check that Client.Serve() returned a non-nil error + for i := 0; i < 10; i++ { + select { + case servErr := <-servErrCh: + if servErr == nil { + t.Fatal("got nil, want non-nil error returned by Client.Serve()") + } + return + default: + time.Sleep(100 * time.Millisecond) + continue + } + } + t.Fatal("No error thrown by Client.Serve()") + +} func TestFUSEReadDir(t *testing.T) { if testing.Short() { t.Skip("skipping fuse tests in short mode.") } fuseDir := randTmpDir(t) - _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, &fakeDialer{}, fuseDir, randTmpDir(t)) defer cleanup() // Initiate a connection so the FUSE server will list it in the dir entries. @@ -221,7 +287,7 @@ func TestFUSEErrors(t *testing.T) { } ctx := context.Background() d := &fakeDialer{} - c, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) + c, _, _ := newTestClient(t, d, randTmpDir(t), randTmpDir(t)) // Simulate FUSE file access by invoking Lookup directly to control // how the socket cache is populated. @@ -261,7 +327,7 @@ func TestFUSEWithBadInstanceName(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + _, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() _, dialErr := net.Dial("unix", filepath.Join(fuseDir, "notvalid")) @@ -280,7 +346,7 @@ func TestFUSECheckConnections(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, cleanup := newTestClient(t, d, fuseDir, randTmpDir(t)) defer cleanup() // first establish a connection to "register" it with the proxy @@ -315,7 +381,7 @@ func TestFUSEClose(t *testing.T) { } fuseDir := randTmpDir(t) d := &fakeDialer{} - c, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) + c, _, _ := newTestClient(t, d, fuseDir, randTmpDir(t)) // first establish a connection to "register" it with the proxy conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) diff --git a/internal/proxy/proxy_other.go b/internal/proxy/proxy_other.go index bac097ca4..47aa3bdc9 100644 --- a/internal/proxy/proxy_other.go +++ b/internal/proxy/proxy_other.go @@ -68,6 +68,7 @@ type fuseMount struct { fuseServerMu *sync.Mutex fuseServer *fuse.Server fuseWg *sync.WaitGroup + fuseExitCh chan error // Inode adds support for FUSE operations. fs.Inode @@ -131,10 +132,20 @@ func (c *Client) Lookup(ctx context.Context, instance string, _ *fuse.EntryOut) defer c.fuseWg.Done() sErr := c.serveSocketMount(ctx, s) if sErr != nil { - c.fuseMu.Lock() c.logger.Debugf("could not serve socket for instance %q: %v", instance, sErr) + c.fuseMu.Lock() + defer c.fuseMu.Unlock() delete(c.fuseSockets, instance) - c.fuseMu.Unlock() + select { + // Best effort attempt to send error. + // If this send fails, it means the reading goroutine has + // already pulled a value out of the channel and is no longer + // reading any more values. In other words, we report only the + // first error. + case c.fuseExitCh <- sErr: + default: + return + } } }() @@ -165,10 +176,16 @@ func (c *Client) serveFuse(ctx context.Context, notify func()) error { } c.fuseServerMu.Lock() c.fuseServer = srv + c.fuseExitCh = make(chan error) + c.fuseServerMu.Unlock() notify() - <-ctx.Done() - return ctx.Err() + select { + case err = <-c.fuseExitCh: + return err + case <-ctx.Done(): + return ctx.Err() + } } func (c *Client) fuseMounts() []*socketMount {