Skip to content

Commit

Permalink
Make write to connection cancellable
Browse files Browse the repository at this point in the history
There were a few issues with the original implementation:

- deadlineWriter didn't use a critical section around
  SetWriteDeadline/Write pair, so an incoming writer moved the deadline
  for pending writes.
- There was a lock around Write in the writeCoalescer implementation.
  This means that all goroutines trying to write to the connection could
  be stuck if Write blocks (i.e. when the TCP buffer is full, probably
  because of the remote node not reading fast enough).
- If a write was queued before the long Write call, it would be added to
  buffers (I'm not sure if an attempt would be made to write it before
  the connection is closed, but it seems possible).

When a connection is stuck in Write and there are other writes queueing
up, we want to abort them if the context is canceled and the write
waiting in queue was not started yet.

We can't cancel writes that are blocked in Write when the context
is canceled because context can be canceled due to external factors
like a user disconnecting. Canceling the pending Write could result
in partial write of a frame, clobbering the connection state.

Added checks for SetWriteDeadline errors, since failing to set the
deadline could leave the write goroutines stuck indefinitely, so it
seems better to just return the error.
It seems that SetWriteDeadline could fail only if the network connection
does not use a facility like epoll, which is highly unlikely. I checked
the Go code and, as far as I can tell, only some file descriptors other
than network connections don't support deadlines.

Also added correct return values (written byte count), since returning
0 when an error occurs is misleading.
  • Loading branch information
martin-sucha committed Mar 18, 2022
1 parent 344f583 commit 830d6d0
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 122 deletions.
289 changes: 189 additions & 100 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ var TimeoutLimit int64 = 0
type Conn struct {
conn net.Conn
r *bufio.Reader
w io.Writer
w contextWriter

timeout time.Duration
cfg *ConnConfig
Expand Down Expand Up @@ -286,9 +286,11 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
streams: streams.New(cfg.ProtoVersion),
host: host,
frameObserver: s.frameObserver,
w: &deadlineWriter{
w: conn,
timeout: cfg.Timeout,
w: &deadlineContextWriter{
w: conn,
timeout: cfg.Timeout,
semaphore: make(chan struct{}, 1),
quit: make(chan struct{}),
},
ctx: ctx,
cancel: cancel,
Expand Down Expand Up @@ -340,7 +342,7 @@ func (c *Conn) init(ctx context.Context) error {
}

func (c *Conn) Write(p []byte) (n int, err error) {
return c.w.Write(p)
return c.w.writeContext(context.Background(), p)
}

func (c *Conn) Read(p []byte) (n int, err error) {
Expand Down Expand Up @@ -785,139 +787,210 @@ type callResp struct {
err error
}

type deadlineWriter struct {
w interface {
SetWriteDeadline(time.Time) error
io.Writer
}
// contextWriter is like io.Writer, but takes context as well.
//
// It is the type of Conn.w. Both deadlineContextWriter and writeCoalescer
// implement it.
type contextWriter interface {
// writeContext writes p to the connection.
//
// If ctx is canceled before we start writing p (e.g. while waiting for another write that is currently in
// progress), p is not written and ctx.Err() is returned. Context is ignored after we start writing p (i.e. we
// don't interrupt blocked writes that are in progress), so cancellation alone never leaves a frame partially
// written.
//
// It returns the number of bytes written from p (0 <= n <= len(p)) and any error that caused the write to stop
// early. writeContext must return a non-nil error if it returns n < len(p). writeContext must not modify the
// data in p, even temporarily.
writeContext(ctx context.Context, p []byte) (n int, err error)
}

// deadlineWriter is an io.Writer whose writes can be bounded by a deadline
// set through SetWriteDeadline. net.Conn satisfies this interface.
//
// It is the type of the underlying writer held by deadlineContextWriter and
// writeCoalescer.
type deadlineWriter interface {
SetWriteDeadline(time.Time) error
io.Writer
}

// deadlineContextWriter is a contextWriter that writes directly to w, setting
// a write deadline before each write when timeout is positive.
//
// The SetWriteDeadline/Write pair is serialized through semaphore so that a
// concurrent writer cannot move the deadline of a write in progress.
type deadlineContextWriter struct {
// w is the underlying writer that receives the data.
w deadlineWriter
// timeout is the per-write timeout. A write deadline is set only when
// timeout > 0.
timeout time.Duration
// semaphore protects critical section for SetWriteDeadline/Write.
// It is a channel with capacity 1: sending to it acquires the section,
// receiving from it releases the section.
semaphore chan struct{}

// quit is closed once the connection is closed. Writers still waiting for
// the semaphore then return ErrConnectionClosed.
quit chan struct{}
}

func (c *deadlineWriter) Write(p []byte) (int, error) {
// writeContext implements contextWriter.
func (c *deadlineContextWriter) writeContext(ctx context.Context, p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-c.quit:
return 0, ErrConnectionClosed
case c.semaphore <- struct{}{}:
// acquired
}

defer func() {
// release
<-c.semaphore
}()

if c.timeout > 0 {
c.w.SetWriteDeadline(time.Now().Add(c.timeout))
err := c.w.SetWriteDeadline(time.Now().Add(c.timeout))
if err != nil {
return 0, err
}
}
return c.w.Write(p)
}

func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer {
func newWriteCoalescer(conn deadlineWriter, writeTimeout, coalesceDuration time.Duration,
quit <-chan struct{}) *writeCoalescer {
wc := &writeCoalescer{
writeCh: make(chan struct{}), // TODO: could this be sync?
cond: sync.NewCond(&sync.Mutex{}),
writeCh: make(chan writeRequest),
c: conn,
quit: quit,
timeout: timeout,
timeout: writeTimeout,
}
go wc.writeFlusher(d)
go wc.writeFlusher(coalesceDuration)
return wc
}

type writeCoalescer struct {
c net.Conn
c deadlineWriter

mu sync.Mutex

quit <-chan struct{}
writeCh chan struct{}
running bool
writeCh chan writeRequest

// cond waits for the buffer to be flushed
cond *sync.Cond
buffers net.Buffers
timeout time.Duration

// result of the write
err error
testEnqueuedHook func()
testFlushedHook func()
}

func (w *writeCoalescer) flushLocked() {
w.running = false
if len(w.buffers) == 0 {
return
}

if w.timeout > 0 {
w.c.SetWriteDeadline(time.Now().Add(w.timeout))
}

// Given we are going to do a fanout n is useless and according to
// the docs WriteTo should return 0 and err or bytes written and
// no error.
_, w.err = w.buffers.WriteTo(w.c)
if w.err != nil {
w.buffers = nil
}
w.cond.Broadcast()
// writeRequest is a single write handed to the writeCoalescer's flusher
// goroutine over writeCh.
type writeRequest struct {
// resultChan is a channel (with buffer size 1) where to send results of the write.
// The buffer lets the flusher deliver the result without blocking.
resultChan chan<- writeResult
// data to write.
data []byte
}

func (w *writeCoalescer) flush() {
w.cond.L.Lock()
w.flushLocked()
w.cond.L.Unlock()
}

func (w *writeCoalescer) stop() {
w.cond.L.Lock()
defer w.cond.L.Unlock()

w.flushLocked()
// nil the channel out sends block forever on it
// instead of closing which causes a send on closed channel
// panic.
w.writeCh = nil
// writeResult is the outcome of a single writeRequest, delivered on the
// request's resultChan and returned to the caller of writeContext.
type writeResult struct {
// n is the number of bytes of the request's data that were written.
n int
// err is the error that stopped the write early; it is nil when the
// request's data was fully written.
err error
}

func (w *writeCoalescer) Write(p []byte) (int, error) {
w.cond.L.Lock()

if !w.running {
select {
case w.writeCh <- struct{}{}:
w.running = true
case <-w.quit:
w.cond.L.Unlock()
return 0, io.EOF // TODO: better error here?
}
// writeContext implements contextWriter.
func (w *writeCoalescer) writeContext(ctx context.Context, p []byte) (int, error) {
resultChan := make(chan writeResult, 1)
wr := writeRequest{
resultChan: resultChan,
data: p,
}

w.buffers = append(w.buffers, p)
for len(w.buffers) != 0 {
w.cond.Wait()
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-w.quit:
return 0, io.EOF // TODO: better error here?
case w.writeCh <- wr:
// enqueued for writing
}

err := w.err
w.cond.L.Unlock()

if err != nil {
return 0, err
if w.testEnqueuedHook != nil {
w.testEnqueuedHook()
}
return len(p), nil

result := <-resultChan
return result.n, result.err
}

func (w *writeCoalescer) writeFlusher(interval time.Duration) {
timer := time.NewTimer(interval)
defer timer.Stop()
defer w.stop()

if !timer.Stop() {
<-timer.C
}

w.writeFlusherImpl(timer.C, func() { timer.Reset(interval) })
}

func (w *writeCoalescer) writeFlusherImpl(timerC <-chan time.Time, resetTimer func()) {
running := false

var buffers net.Buffers
var resultChans []chan<- writeResult

for {
// wait for a write to start the flush loop
select {
case <-w.writeCh:
case req := <-w.writeCh:
buffers = append(buffers, req.data)
resultChans = append(resultChans, req.resultChan)
if !running {
// Start timer on first write.
resetTimer()
running = true
}
case <-w.quit:
result := writeResult{
n: 0,
err: io.EOF, // TODO: better error here?
}
// Unblock whoever was waiting.
for _, resultChan := range resultChans {
// resultChan has capacity 1, so it does not block.
resultChan <- result
}
return
case <-timerC:
running = false
w.flush(resultChans, buffers)
buffers = nil
resultChans = nil
if w.testFlushedHook != nil {
w.testFlushedHook()
}
}
}
}

timer.Reset(interval)

select {
case <-w.quit:
func (w *writeCoalescer) flush(resultChans []chan<- writeResult, buffers net.Buffers) {
// Flush everything we have so far.
if w.timeout > 0 {
err := w.c.SetWriteDeadline(time.Now().Add(w.timeout))
if err != nil {
for i := range resultChans {
resultChans[i] <- writeResult{
n: 0,
err: err,
}
}
return
case <-timer.C:
}

w.flush()
}
// Copy buffers because WriteTo modifies buffers in-place.
buffers2 := make(net.Buffers, len(buffers))
copy(buffers2, buffers)
n, err := buffers2.WriteTo(w.c)
// Writes of bytes before n succeeded, writes of bytes starting from n failed with err.
// Use n as remaining byte counter.
for i := range buffers {
if int64(len(buffers[i])) <= n {
// this buffer was fully written.
resultChans[i] <- writeResult{
n: len(buffers[i]),
err: nil,
}
n -= int64(len(buffers[i]))
} else {
// this buffer was not (fully) written.
resultChans[i] <- writeResult{
n: int(n),
err: err,
}
n = 0
}
}
}

Expand Down Expand Up @@ -969,23 +1042,39 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram
err := req.buildFrame(framer, stream)
if err != nil {
// We failed to serialize the frame into a buffer.
// This should not affect the connection, we just free the current call.
// This should not affect the connection as we didn't write anything. We just free the current call.
c.mu.Lock()
delete(c.calls, call.streamID)
c.mu.Unlock()
// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
// check above could fail.
c.releaseStream(call)
return nil, err
}

err = framer.writeTo(c)
n, err := c.w.writeContext(ctx, framer.buf)
if err != nil {
// closeWithError will block waiting for this stream to either receive a response
// or for us to timeout, close the timeout chan here. Im not entirely sure
// but we should not get a response after an error on the write side.
close(call.timeout)
// I think this is the correct thing to do, im not entirely sure. It is not
// ideal as readers might still get some data, but they probably wont.
// Here we need to be careful as the stream is not available and if all
// writes just timeout or fail then the pool might use this connection to
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
if (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) && n == 0 {
// We have not started to write this frame.
// Release the stream as no response can come from the server on the stream.
c.mu.Lock()
delete(c.calls, call.streamID)
c.mu.Unlock()
// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
// check above could fail.
c.releaseStream(call)
} else {
// closeWithError will block waiting for this stream to either receive a response
// or for us to timeout, close the timeout chan here. Im not entirely sure
// but we should not get a response after an error on the write side.
close(call.timeout)
// I think this is the correct thing to do, im not entirely sure. It is not
// ideal as readers might still get some data, but they probably wont.
// Here we need to be careful as the stream is not available and if all
// writes just timeout or fail then the pool might use this connection to
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
}
return nil, err
}

Expand Down
Loading

0 comments on commit 830d6d0

Please sign in to comment.