From 90c365e70d4cc5f75df682e112649612fb47eaba Mon Sep 17 00:00:00 2001 From: Gabriel Villalonga Simon Date: Mon, 12 Feb 2024 16:34:35 +0000 Subject: [PATCH] Wait for stdout pipe to close before calling runner.Wait() (#299) If the two goroutines are left to race each other, when runner.Wait() wins it will close the file and cause the stdout scanner to log a spurious os.ErrClosed error instead of returning nil after encountering an io.EOF error. --- client.go | 18 ++++++++++-------- client_test.go | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 73f6b35..b813ba8 100644 --- a/client.go +++ b/client.go @@ -104,9 +104,9 @@ type Client struct { // goroutines. clientWaitGroup sync.WaitGroup - // stderrWaitGroup is used to prevent the command's Wait() function from - // being called before we've finished reading from the stderr pipe. - stderrWaitGroup sync.WaitGroup + // pipesWaitGroup is used to prevent the command's Wait() function from + // being called before we've finished reading from the stdout and stderr pipe. + pipesWaitGroup sync.WaitGroup // processKilled is used for testing only, to flag when the process was // forcefully killed. @@ -756,8 +756,8 @@ func (c *Client) Start() (addr net.Addr, err error) { // Start goroutine that logs the stderr c.clientWaitGroup.Add(1) - c.stderrWaitGroup.Add(1) - // logStderr calls Done() + c.pipesWaitGroup.Add(1) + // logStderr calls c.pipesWaitGroup.Done() go c.logStderr(runner.Name(), runner.Stderr()) c.clientWaitGroup.Add(1) @@ -767,9 +767,9 @@ func (c *Client) Start() (addr net.Addr, err error) { defer c.clientWaitGroup.Done() - // wait to finish reading from stderr since the stderr pipe reader + // wait to finish reading from stdout/stderr since the stdout/stderr pipe readers // will be closed by the subsequent call to cmd.Wait(). - c.stderrWaitGroup.Wait() + c.pipesWaitGroup.Wait() // Wait for the command to end. err := runner.Wait(context.Background()) @@ -792,8 +792,10 @@ func (c *Client) Start() (addr net.Addr, err error) { // out of stdout linesCh := make(chan string) c.clientWaitGroup.Add(1) + c.pipesWaitGroup.Add(1) go func() { defer c.clientWaitGroup.Done() + defer c.pipesWaitGroup.Done() defer close(linesCh) scanner := bufio.NewScanner(runner.Stdout()) @@ -1159,7 +1161,7 @@ func (c *Client) getGRPCMuxer(addr net.Addr) (*grpcmux.GRPCClientMuxer, error) { func (c *Client) logStderr(name string, r io.Reader) { defer c.clientWaitGroup.Done() - defer c.stderrWaitGroup.Done() + defer c.pipesWaitGroup.Done() l := c.logger.Named(filepath.Base(name)) reader := bufio.NewReaderSize(r, c.config.PluginLogBufferSize) diff --git a/client_test.go b/client_test.go index a17bf10..c51a371 100644 --- a/client_test.go +++ b/client_test.go @@ -1503,7 +1503,7 @@ this line is short reader := strings.NewReader(msg) - c.stderrWaitGroup.Add(1) + c.pipesWaitGroup.Add(1) c.logStderr(c.config.Cmd.Path, reader) read := stderr.String() @@ -1531,7 +1531,7 @@ func TestClient_logStderrParseJSON(t *testing.T) { {"@message": "this is a large message that is more than 64 bytes long", "@level": "info"}` reader := strings.NewReader(msg) - c.stderrWaitGroup.Add(1) + c.pipesWaitGroup.Add(1) c.logStderr(c.config.Cmd.Path, reader) logs := strings.Split(strings.TrimSpace(logBuf.String()), "\n")